Commit cc6e6b7d authored by Wang, Leping's avatar Wang, Leping
Browse files

- Add config.sh with all pipeline parameters organized by category

  (molecular, crystal structure, compute, run mode, path)
- Refactor search_gen_proc.sh to source config.sh instead of
  hardcoding parameters, with optional config path argument
- Refactor structure_generate.py to load config.sh via exec(),
  replacing hardcoded values with config-driven parameters
- Remove mace-bench (the relaxation part; it will be replaced by an updated separate mace-bench project)
parent 61ec3ad9
#--------------------------- Simulation variables -----------------------------#
# Simulation control parameters.
variable T equal 500
# Simulation steps (t_eq)
variable t_eq equal 100
variable output equal 1 #freq to print output
variable dumpstep equal 1
#------------------------------------------------------------------------------#
#---------------------------- Atomic setup ------------------------------------#
units metal
boundary p p p
# Create atoms.
box tilt large
read_data ./res.dat
replicate 2 2 2
# Define interatomic potential.
pair_style e3gnn/parallel
# The order of element should be the same as the order of elements in the data file (type)
# * * {number of deployed parallel models} {path to deployed parallel models} {elements}
pair_coeff * * 4 ./deployed_parallel Hf O
timestep 0.002
#----------------------------- Run simulation ---------------------------------#
# Setup output
thermo ${output}
thermo_style custom step tpcpu pe ke vol press temp
dump mydump all custom 1 dump.traj id type x y z fx fy fz
dump_modify mydump sort id
fix f1 all nve
fix comfix all momentum 1 linear 1 1 1
velocity all create ${T} 1 dist gaussian mom yes
run 5
#------------------------------------------------------------------------------#
# generated from poscar module (Converted from VASP)
96 atoms
2 atom types
0.00000000 10.12978631 xlo xhi
0.00000000 10.37111894 ylo yhi
0.00000000 10.26314330 zlo zhi
1.73035484 0.00000000 0.00000000 xy xz yz
Masses
1 178.490000
2 16.000000
Atoms
1 1 10.07858846 8.73752520 7.46334159
2 1 9.43739481 3.62108575 2.41757962
3 1 8.95522965 1.01658239 0.30135971
4 1 10.05681289 8.87631272 2.35008881
5 1 9.75638582 6.30304558 0.18303097
6 1 8.91443797 1.06674577 5.54199800
7 1 9.79081409 6.26490552 5.22104118
8 1 4.41792703 3.81788503 2.39528244
9 1 2.17257241 4.11562894 4.98715163
10 1 1.66417958 1.51957159 2.91326352
11 1 5.17465464 8.94063393 2.34565890
12 1 2.85135280 9.31925586 4.96912733
13 1 5.06412333 8.89798173 7.34081071
14 1 2.43171878 6.73610078 7.90086842
15 1 4.79358115 6.33415175 5.30470215
16 1 4.38083973 3.71776573 7.50514775
17 1 2.16740031 4.19218670 10.14985882
18 1 1.72740111 1.53804994 8.11741499
19 1 3.81409500 1.12760478 5.36341140
20 1 2.91124600 9.28502610 10.05473353
21 1 2.43508812 6.79222574 2.84272933
22 1 4.63828767 6.35818290 0.06495312
23 1 3.96862029 1.21993115 0.31149240
24 1 7.95945527 9.30919854 9.99874912
25 1 7.20953615 4.10124686 10.06760490
26 1 7.97413259 9.31020496 4.88564505
27 1 7.48292324 6.67094881 2.71417363
28 1 7.22354435 4.09214379 4.91151391
29 1 6.77110223 1.50790865 2.82107672
30 1 7.60040235 6.67278986 7.89547614
31 1 9.41001699 3.69515647 7.55969566
32 1 6.77374601 1.50559611 7.98091486
33 2 10.09033948 1.85325366 6.96757279
34 2 10.99511263 7.08138493 6.69600728
35 2 11.28967512 9.60939027 6.04904275
36 2 10.78802723 6.95960696 1.75182521
37 2 8.32036165 2.47159112 3.91830180
38 2 10.07211065 1.65854646 1.87162718
39 2 10.48460341 4.35174196 0.83957607
40 2 9.58543540 10.25525675 3.81736961
41 2 9.15581971 7.61785575 3.82383341
42 2 11.28280065 9.42551408 0.89622391
43 2 10.61871694 4.46990764 6.09305074
44 2 3.63623823 5.03045204 3.82116882
45 2 3.22065417 2.54566037 3.96054713
46 2 5.13228413 1.96508633 1.75400085
47 2 0.49972311 0.85950986 4.41955106
48 2 4.49795126 10.25514684 3.94724065
49 2 4.08831630 7.74684459 3.86494489
50 2 5.90127151 7.07173966 1.66530214
51 2 6.32939463 9.71904813 0.93714394
52 2 5.44334962 4.45855194 0.80863483
53 2 3.54898001 7.91810753 6.27738026
54 2 3.21454568 5.33370189 6.39780696
55 2 4.08744202 7.79255573 8.94430580
56 2 5.89610679 7.08975961 6.71713664
57 2 1.22356400 5.98776133 9.31853387
58 2 6.36149555 9.67586631 6.06078686
59 2 3.62496983 5.17884441 8.87967862
60 2 2.79091790 2.70312051 6.50630496
61 2 2.20642369 0.12199325 6.50943225
62 2 3.36642075 2.52335751 9.09596150
63 2 5.04873139 1.85222299 6.87617402
64 2 0.50463414 0.78101050 9.52304998
65 2 0.95275450 3.34556660 8.69553848
66 2 5.46251140 4.47527874 6.00302102
67 2 4.49817894 10.30428629 8.95760259
68 2 1.73952474 8.56368775 8.58555417
69 2 3.62337333 7.98128068 1.17332847
70 2 3.16118634 5.40721734 1.22906859
71 2 1.31418904 5.97276567 4.19215008
72 2 2.79848347 2.74889192 1.27175505
73 2 2.29397766 0.16984188 1.46074130
74 2 0.96707618 3.44188647 3.47135866
75 2 1.75476287 8.54415321 3.43405286
76 2 9.58968370 10.11019896 9.02264863
77 2 9.12010305 7.60109276 8.90873480
78 2 8.21690550 5.34991345 1.37662298
79 2 6.43083748 5.94959474 4.24019201
80 2 8.69454708 4.99651781 3.88811596
81 2 7.96833563 2.69478887 1.20584684
82 2 7.33313963 0.16345962 1.36457040
83 2 5.57830829 0.81411424 4.28350294
84 2 6.12969978 3.45445964 3.35501959
85 2 8.55281287 7.89834015 1.15270373
86 2 6.84047748 8.56744003 3.45602018
87 2 7.33002089 0.08127731 6.53509276
88 2 5.55795729 0.84809385 9.39286244
89 2 8.58051171 7.87084401 6.29048766
90 2 8.28348300 5.22136388 6.47631887
91 2 6.42127905 6.05984919 9.37655897
92 2 6.82699499 8.52585391 8.47852948
93 2 8.68865624 5.09200133 8.99426977
94 2 7.84842127 2.70204335 6.42490863
95 2 8.25796997 2.46457586 9.01312304
96 2 6.08861225 3.44495302 8.57088233
#--------------------------- Simulation variables -----------------------------#
# Simulation control parameters.
variable T equal 500
# Simulation steps (t_eq)
variable t_eq equal 100
variable output equal 1 #freq to print output
variable dumpstep equal 1
#------------------------------------------------------------------------------#
#---------------------------- Atomic setup ------------------------------------#
units metal
boundary p p p
# Create atoms.
box tilt large
read_data ./res.dat
replicate 2 2 2
# Define interatomic potential.
pair_style e3gnn
# The order of element should be the same as the order of elements in the data file (type)
# * * {path to deployed serial model} {elements}
pair_coeff * * ./deployed_serial.pt Hf O
timestep 0.002
#----------------------------- Run simulation ---------------------------------#
# Setup output
thermo ${output} #because it is relaxation
thermo_style custom step tpcpu pe ke vol press temp #record these value (custom setting)
dump mydump all custom 1 dump.traj id type x y z fx fy fz
dump_modify mydump sort id
fix f1 all nve
fix comfix all momentum 1 linear 1 1 1
velocity all create ${T} 1 dist gaussian mom yes
run 5
#------------------------------------------------------------------------------#
# generated from poscar module (Converted from VASP)
96 atoms
2 atom types
0.00000000 10.12978631 xlo xhi
0.00000000 10.37111894 ylo yhi
0.00000000 10.26314330 zlo zhi
1.73035484 0.00000000 0.00000000 xy xz yz
Masses
1 178.490000
2 16.000000
Atoms
1 1 10.07858846 8.73752520 7.46334159
2 1 9.43739481 3.62108575 2.41757962
3 1 8.95522965 1.01658239 0.30135971
4 1 10.05681289 8.87631272 2.35008881
5 1 9.75638582 6.30304558 0.18303097
6 1 8.91443797 1.06674577 5.54199800
7 1 9.79081409 6.26490552 5.22104118
8 1 4.41792703 3.81788503 2.39528244
9 1 2.17257241 4.11562894 4.98715163
10 1 1.66417958 1.51957159 2.91326352
11 1 5.17465464 8.94063393 2.34565890
12 1 2.85135280 9.31925586 4.96912733
13 1 5.06412333 8.89798173 7.34081071
14 1 2.43171878 6.73610078 7.90086842
15 1 4.79358115 6.33415175 5.30470215
16 1 4.38083973 3.71776573 7.50514775
17 1 2.16740031 4.19218670 10.14985882
18 1 1.72740111 1.53804994 8.11741499
19 1 3.81409500 1.12760478 5.36341140
20 1 2.91124600 9.28502610 10.05473353
21 1 2.43508812 6.79222574 2.84272933
22 1 4.63828767 6.35818290 0.06495312
23 1 3.96862029 1.21993115 0.31149240
24 1 7.95945527 9.30919854 9.99874912
25 1 7.20953615 4.10124686 10.06760490
26 1 7.97413259 9.31020496 4.88564505
27 1 7.48292324 6.67094881 2.71417363
28 1 7.22354435 4.09214379 4.91151391
29 1 6.77110223 1.50790865 2.82107672
30 1 7.60040235 6.67278986 7.89547614
31 1 9.41001699 3.69515647 7.55969566
32 1 6.77374601 1.50559611 7.98091486
33 2 10.09033948 1.85325366 6.96757279
34 2 10.99511263 7.08138493 6.69600728
35 2 11.28967512 9.60939027 6.04904275
36 2 10.78802723 6.95960696 1.75182521
37 2 8.32036165 2.47159112 3.91830180
38 2 10.07211065 1.65854646 1.87162718
39 2 10.48460341 4.35174196 0.83957607
40 2 9.58543540 10.25525675 3.81736961
41 2 9.15581971 7.61785575 3.82383341
42 2 11.28280065 9.42551408 0.89622391
43 2 10.61871694 4.46990764 6.09305074
44 2 3.63623823 5.03045204 3.82116882
45 2 3.22065417 2.54566037 3.96054713
46 2 5.13228413 1.96508633 1.75400085
47 2 0.49972311 0.85950986 4.41955106
48 2 4.49795126 10.25514684 3.94724065
49 2 4.08831630 7.74684459 3.86494489
50 2 5.90127151 7.07173966 1.66530214
51 2 6.32939463 9.71904813 0.93714394
52 2 5.44334962 4.45855194 0.80863483
53 2 3.54898001 7.91810753 6.27738026
54 2 3.21454568 5.33370189 6.39780696
55 2 4.08744202 7.79255573 8.94430580
56 2 5.89610679 7.08975961 6.71713664
57 2 1.22356400 5.98776133 9.31853387
58 2 6.36149555 9.67586631 6.06078686
59 2 3.62496983 5.17884441 8.87967862
60 2 2.79091790 2.70312051 6.50630496
61 2 2.20642369 0.12199325 6.50943225
62 2 3.36642075 2.52335751 9.09596150
63 2 5.04873139 1.85222299 6.87617402
64 2 0.50463414 0.78101050 9.52304998
65 2 0.95275450 3.34556660 8.69553848
66 2 5.46251140 4.47527874 6.00302102
67 2 4.49817894 10.30428629 8.95760259
68 2 1.73952474 8.56368775 8.58555417
69 2 3.62337333 7.98128068 1.17332847
70 2 3.16118634 5.40721734 1.22906859
71 2 1.31418904 5.97276567 4.19215008
72 2 2.79848347 2.74889192 1.27175505
73 2 2.29397766 0.16984188 1.46074130
74 2 0.96707618 3.44188647 3.47135866
75 2 1.75476287 8.54415321 3.43405286
76 2 9.58968370 10.11019896 9.02264863
77 2 9.12010305 7.60109276 8.90873480
78 2 8.21690550 5.34991345 1.37662298
79 2 6.43083748 5.94959474 4.24019201
80 2 8.69454708 4.99651781 3.88811596
81 2 7.96833563 2.69478887 1.20584684
82 2 7.33313963 0.16345962 1.36457040
83 2 5.57830829 0.81411424 4.28350294
84 2 6.12969978 3.45445964 3.35501959
85 2 8.55281287 7.89834015 1.15270373
86 2 6.84047748 8.56744003 3.45602018
87 2 7.33002089 0.08127731 6.53509276
88 2 5.55795729 0.84809385 9.39286244
89 2 8.58051171 7.87084401 6.29048766
90 2 8.28348300 5.22136388 6.47631887
91 2 6.42127905 6.05984919 9.37655897
92 2 6.82699499 8.52585391 8.47852948
93 2 8.68865624 5.09200133 8.99426977
94 2 7.84842127 2.70204335 6.42490863
95 2 8.25796997 2.46457586 9.01312304
96 2 6.08861225 3.44495302 8.57088233
# Example input.yaml for training SevenNet.
# '*' signifies default. You can check log.sevenn for defaults.
model:
chemical_species: 'Auto' # Elements model should know. [ 'Univ' | 'Auto' | manual_user_input ]
cutoff: 5.0 # Cutoff radius in Angstroms. If two atoms are within the cutoff, they are connected.
channel: 32 # The multiplicity(channel) of node features.
lmax: 2 # Maximum order of irreducible representations (rotation order).
num_convolution_layer: 3 # The number of message passing layers.
#irreps_manual: # Manually set irreps of the model in each layer
#- "128x0e"
#- "128x0e+64x1e+32x2e"
#- "128x0e+64x1e+32x2e"
#- "128x0e+64x1e+32x2e"
#- "128x0e+64x1e+32x2e"
#- "128x0e"
weight_nn_hidden_neurons: [64, 64] # Hidden neurons in convolution weight neural network
radial_basis: # Function and its parameters to encode radial distance
radial_basis_name: 'bessel' # Only 'bessel' is currently supported
bessel_basis_num: 8
cutoff_function: # Envelope function, multiplied by radial_basis functions to initialize edge features
cutoff_function_name: 'poly_cut' # {'poly_cut' and 'poly_cut_p_value'} or {'XPLOR' and 'cutoff_on'}
poly_cut_p_value: 6
act_gate: {'e': 'silu', 'o': 'tanh'} # Equivalent to 'nonlinearity_gates' in nequip
act_scalar: {'e': 'silu', 'o': 'tanh'} # Equivalent to 'nonlinearity_scalars' in nequip
is_parity: False # Parity True (E(3) group) or False (SE(3) group)
self_connection_type: 'nequip' # Default is 'nequip'. 'linear' is used for SevenNet-0. I recommend 'linear' for 'Univ' chemical_species
conv_denominator: "avg_num_neigh" # Valid options are "avg_num_neigh*", "sqrt_avg_num_neigh", or float
train_denominator: False # Enable training for denominator in convolution layer
train_shift_scale: False # Enable training for shift & scale in output layer
train:
random_seed: 1
is_train_stress: True # Includes stress in the loss function
epoch: 200 # Ends training after this number of epochs
#loss: 'Huber' # Default is 'mse' (mean squared error)
#loss_param:
#delta: 0.01
# Each optimizer and scheduler have different available parameters.
# You can refer to sevenn/train/optim.py for supporting optimizer & schedulers
optimizer: 'adam' # Options available are 'sgd', 'adagrad', 'adam', 'adamw', 'radam'
optim_param:
lr: 0.005
scheduler: 'exponentiallr' # 'steplr', 'multisteplr', 'exponentiallr', 'cosineannealinglr', 'reducelronplateau', 'linearlr'
scheduler_param:
gamma: 0.99
force_loss_weight: 0.1 # Coefficient for force loss
stress_loss_weight: 1e-06 # Coefficient for stress loss (to kbar unit)
per_epoch: 10 # Generate checkpoints every this epoch
# ['target y', 'metric']
# Target y: TotalEnergy, Energy, Force, Stress, Stress_GPa, TotalLoss
# Metric : RMSE, MAE, or Loss
error_record:
- ['Energy', 'RMSE']
- ['Force', 'RMSE']
- ['Stress', 'RMSE']
- ['TotalLoss', 'None']
# Continue training model from given checkpoint, or pre-trained model checkpoint for fine-tuning
#continue:
#checkpoint: 'checkpoint_best.pth' # Checkpoint of pre-trained model or a model want to continue training.
#reset_optimizer: False # Set True for fine-tuning
#reset_scheduler: False # Set True for fine-tuning
data:
batch_size: 4 # Per GPU batch size.
data_divide_ratio: 0.1 # Split dataset into training and validation sets by this ratio
shift: 'per_atom_energy_mean' # One of 'per_atom_energy_mean*', 'elemwise_reference_energies', float
scale: 'force_rms' # One of 'force_rms*', 'per_atom_energy_std', float
# SevenNet automatically matches data format from its filename.
# For those not `structure_list` or `.pt` files, assumes it is ASE readable
# In this case, below arguments are directly passed to `ase.io.read`
data_format_args:
index: ':' # see `https://wiki.fysik.dtu.dk/ase/ase/io/io.html` for more valid arguments
# validset is needed if you want '_best.pth' during training. Otherwise, both validset and testset are optional.
load_trainset_path: ['./structure_list'] # Example of using ase as data_format, support multiple files and expansion(*)
#load_validset_path: ['./valid.extxyz']
#load_testset_path: ['./sevenn_data/mydata.pt'] # Graph can be preprocessed using `sevenn_graph_build` and accessible like this
[label_1]
../data/label_1/OUTCAR_{1..5} :
../data/label_1/OUTCAR_{1..5} :
[label_2]
../data/label_2/OUTCAR_{6..10} :
../data/label_2/OUTCAR_{6..10} :
[project]
name = "sevenn"
version = "0.11.1.dev3"
authors = [
{ name = "Yutack Park", email = "parkyutack@snu.ac.kr" },
{ name = "Haekwan Jeon", email = "haekwan98@snu.ac.kr" },
{ name = "Jaesun Kim" },
{ name = "Gijin Kim" },
{ name = "Hyungmin An" },
]
description = "Scalable EquiVariance Enabled Neural Network"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Operating System :: POSIX :: Linux",
]
dependencies = [
"ase",
"braceexpand",
"pyyaml",
"e3nn>=0.5.0",
"tqdm",
"scikit-learn",
"torch_geometric>=2.5.0",
"numpy",
"matscipy",
"pandas",
"requests",
"setuptools>=61.0"
]
[project.optional-dependencies]
test = ["pytest", "pytest-cov>=5"]
cueq12 = ["cuequivariance>=0.4.0; python_version >= '3.10'", "cuequivariance-torch>=0.4.0; python_version >= '3.10'", "cuequivariance-ops-torch-cu12; python_version >= '3.10'"]
cueq11 = ["cuequivariance>=0.4.0; python_version >= '3.10'", "cuequivariance-torch>=0.4.0; python_version >= '3.10'", "cuequivariance-ops-torch-cu11; python_version >= '3.10'"]
[project.scripts]
sevenn = "sevenn.main.sevenn:main"
sevenn_get_model = "sevenn.main.sevenn_get_model:main"
sevenn_graph_build = "sevenn.main.sevenn_graph_build:main"
sevenn_inference = "sevenn.main.sevenn_inference:main"
sevenn_patch_lammps = "sevenn.main.sevenn_patch_lammps:main"
sevenn_preset = "sevenn.main.sevenn_preset:main"
sevenn_cp = "sevenn.main.sevenn_cp:main"
[project.urls]
Homepage = "https://github.com/MDIL-SNU/SevenNet"
Issues = "https://github.com/MDIL-SNU/SevenNet/issues"
[build-system]
build-backend = "setuptools.build_meta"
requires = ["setuptools>=61.0"]
[tool.setuptools.package-data]
sevenn = [
"logo_ascii",
"*.so",
"pair_e3gnn/*.cpp",
"pair_e3gnn/*.h",
"pair_e3gnn/*.cu",
"pair_e3gnn/patch_lammps.sh",
"presets/*.yaml",
"pretrained_potentials/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth",
"pretrained_potentials/SevenNet_0__22May2024/checkpoint_sevennet_0.pth",
"pretrained_potentials/SevenNet_l3i5/checkpoint_l3i5.pth",
"pretrained_potentials/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth",
"py.typed",
]
[tool.setuptools.packages.find]
include = ["sevenn*"]
exclude = ["tests*", "example_inputs*", ]
[tool.pytest.ini_options]
log_cli = true
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
[tool.ruff]
line-length = 85
[tool.ruff.lint]
extend-select = ["E501"]
[tool.ruff.format]
quote-style = "single"
docstring-code-format = true
[flake8]
max-line-length = 85
max-complexity = 12
select = C,E,F,W,B,B950
ignore = F401, W503, W605, E741, E203, C901, E722
[isort]
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=80
known_third_party=ase,braceexpand,e3nn,numpy,packaging,pandas,pytest,requests,sklearn,torch,torch_geometric,tqdm,yaml
known_first_party=
This diff is collapsed.
LICENSE
README.md
pyproject.toml
setup.cfg
sevenn/__init__.py
sevenn/_const.py
sevenn/_keys.py
sevenn/atom_graph_data.py
sevenn/calculator.py
sevenn/checkpoint.py
sevenn/error_recorder.py
sevenn/logger.py
sevenn/logo_ascii
sevenn/model_build.py
sevenn/pair_d3.so
sevenn/parse_input.py
sevenn/py.typed
sevenn/sevenn_logger.py
sevenn/sevennet_calculator.py
sevenn/util.py
sevenn.egg-info/PKG-INFO
sevenn.egg-info/SOURCES.txt
sevenn.egg-info/dependency_links.txt
sevenn.egg-info/entry_points.txt
sevenn.egg-info/requires.txt
sevenn.egg-info/top_level.txt
sevenn/main/__init__.py
sevenn/main/sevenn.py
sevenn/main/sevenn_cp.py
sevenn/main/sevenn_get_model.py
sevenn/main/sevenn_graph_build.py
sevenn/main/sevenn_inference.py
sevenn/main/sevenn_patch_lammps.py
sevenn/main/sevenn_preset.py
sevenn/nn/__init__.py
sevenn/nn/activation.py
sevenn/nn/convolution.py
sevenn/nn/cue_helper.py
sevenn/nn/edge_embedding.py
sevenn/nn/equivariant_gate.py
sevenn/nn/force_output.py
sevenn/nn/interaction_blocks.py
sevenn/nn/linear.py
sevenn/nn/node_embedding.py
sevenn/nn/scale.py
sevenn/nn/self_connection.py
sevenn/nn/sequential.py
sevenn/nn/util.py
sevenn/pair_e3gnn/comm_brick.cpp
sevenn/pair_e3gnn/comm_brick.h
sevenn/pair_e3gnn/pair_d3.cu
sevenn/pair_e3gnn/pair_d3.h
sevenn/pair_e3gnn/pair_d3_for_ase.cu
sevenn/pair_e3gnn/pair_d3_for_ase.h
sevenn/pair_e3gnn/pair_d3_pars.h
sevenn/pair_e3gnn/pair_e3gnn.cpp
sevenn/pair_e3gnn/pair_e3gnn.h
sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp
sevenn/pair_e3gnn/pair_e3gnn_parallel.h
sevenn/pair_e3gnn/patch_lammps.sh
sevenn/presets/MF_0.yaml
sevenn/presets/base.yaml
sevenn/presets/fine_tune.yaml
sevenn/presets/fine_tune_le.yaml
sevenn/presets/multi_modal.yaml
sevenn/presets/sevennet-0.yaml
sevenn/presets/sevennet-l3i5.yaml
sevenn/scripts/__init__.py
sevenn/scripts/backward_compatibility.py
sevenn/scripts/convert_model_modality.py
sevenn/scripts/deploy.py
sevenn/scripts/graph_build.py
sevenn/scripts/inference.py
sevenn/scripts/processing_continue.py
sevenn/scripts/processing_dataset.py
sevenn/scripts/processing_epoch.py
sevenn/scripts/train.py
sevenn/train/__init__.py
sevenn/train/atoms_dataset.py
sevenn/train/collate.py
sevenn/train/dataload.py
sevenn/train/dataset.py
sevenn/train/graph_dataset.py
sevenn/train/loss.py
sevenn/train/modal_dataset.py
sevenn/train/optim.py
sevenn/train/trainer.py
\ No newline at end of file
[console_scripts]
sevenn = sevenn.main.sevenn:main
sevenn_cp = sevenn.main.sevenn_cp:main
sevenn_get_model = sevenn.main.sevenn_get_model:main
sevenn_graph_build = sevenn.main.sevenn_graph_build:main
sevenn_inference = sevenn.main.sevenn_inference:main
sevenn_patch_lammps = sevenn.main.sevenn_patch_lammps:main
sevenn_preset = sevenn.main.sevenn_preset:main
ase
braceexpand
pyyaml
e3nn>=0.5.0
tqdm
scikit-learn
torch_geometric>=2.5.0
numpy
matscipy
pandas
requests
setuptools>=61.0
[cueq11]
[cueq11:python_version >= "3.10"]
cuequivariance>=0.4.0
cuequivariance-torch>=0.4.0
cuequivariance-ops-torch-cu11
[cueq12]
[cueq12:python_version >= "3.10"]
cuequivariance>=0.4.0
cuequivariance-torch>=0.4.0
cuequivariance-ops-torch-cu12
[test]
pytest
pytest-cov>=5
from importlib.metadata import version
from packaging.version import Version

# Package version is resolved from installed distribution metadata.
__version__ = version('sevenn')

# Guard against incompatible e3nn releases: the Clebsch-Gordan coefficient
# convention changed in e3nn 0.5.0, so older versions would silently produce
# wrong equivariant tensor products. Fail loudly at import time instead.
from e3nn import __version__ as e3nn_ver

if Version(e3nn_ver) < Version('0.5.0'):
    raise ValueError(
        'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient '
        'convention.'
    )
import os
from enum import Enum
from typing import Dict
import torch
import sevenn._keys as KEY
from sevenn.nn.activation import ShiftedSoftPlus
NUM_UNIV_ELEMENT = 119 # Z = 0 ~ 118
IMPLEMENTED_RADIAL_BASIS = ['bessel']
IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR']
# TODO: support None. This became difficult because of parallel model
IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear']
IMPLEMENTED_INTERACTION_TYPE = ['nequip']
IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies']
IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms']
SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss']
SUPPORTING_ERROR_TYPES = [
'TotalEnergy',
'Energy',
'Force',
'Stress',
'Stress_GPa',
'TotalLoss',
]
IMPLEMENTED_MODEL = ['E3_equivariant_model']
# string input to real torch function
ACTIVATION = {
'relu': torch.nn.functional.relu,
'silu': torch.nn.functional.silu,
'tanh': torch.tanh,
'abs': torch.abs,
'ssp': ShiftedSoftPlus,
'sigmoid': torch.sigmoid,
'elu': torch.nn.functional.elu,
}
ACTIVATION_FOR_EVEN = {
'ssp': ShiftedSoftPlus,
'silu': torch.nn.functional.silu,
}
ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs}
ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD}
_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials')
SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth'
SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth'
SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth'
SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth'
SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth'
SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth'
_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download'
CHECKPOINT_DOWNLOAD_LINKS = {
SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth',
SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth',
}
# to avoid torch script to compile torch_geometry.data
AtomGraphDataType = Dict[str, torch.Tensor]
class LossType(Enum):  # only used for train_v1, do not use it afterwards
    """Legacy enumeration of loss targets (train_v1 only)."""

    ENERGY = 'energy'  # eV or eV/atom
    FORCE = 'force'  # eV/A
    STRESS = 'stress'  # kB
def error_record_condition(x):
    """Validate an ``error_record`` input.

    A valid value is a list of two-element lists ``[target, metric]`` where
    ``target`` is one of SUPPORTING_ERROR_TYPES and ``metric`` is one of
    SUPPORTING_METRICS. The 'TotalLoss' target is exempt from the metric
    check. Returns True when valid, False otherwise.
    """
    if type(x) is not list:
        return False
    for pair in x:
        if type(pair) is not list or len(pair) != 2:
            return False
        target, metric = pair
        if target not in SUPPORTING_ERROR_TYPES:
            return False
        # 'TotalLoss' does not carry a real metric; skip the metric check.
        if target != 'TotalLoss' and metric not in SUPPORTING_METRICS:
            return False
    return True
DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = {
KEY.CUTOFF: 4.5,
KEY.NODE_FEATURE_MULTIPLICITY: 32,
KEY.IRREPS_MANUAL: False,
KEY.LMAX: 1,
KEY.LMAX_EDGE: -1, # -1 means lmax_edge = lmax
KEY.LMAX_NODE: -1, # -1 means lmax_node = lmax
KEY.IS_PARITY: True,
KEY.NUM_CONVOLUTION: 3,
KEY.RADIAL_BASIS: {
KEY.RADIAL_BASIS_NAME: 'bessel',
},
KEY.CUTOFF_FUNCTION: {
KEY.CUTOFF_FUNCTION_NAME: 'poly_cut',
},
KEY.ACTIVATION_RADIAL: 'silu',
KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'},
KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'},
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64],
# KEY.AVG_NUM_NEIGH: True, # deprecated
# KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated
KEY.CONV_DENOMINATOR: 'avg_num_neigh',
KEY.TRAIN_DENOMINTAOR: False,
KEY.TRAIN_SHIFT_SCALE: False,
# KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True
KEY.USE_BIAS_IN_LINEAR: False,
KEY.USE_MODAL_NODE_EMBEDDING: False,
KEY.USE_MODAL_SELF_INTER_INTRO: False,
KEY.USE_MODAL_SELF_INTER_OUTRO: False,
KEY.USE_MODAL_OUTPUT_BLOCK: False,
KEY.READOUT_AS_FCN: False,
# Applied af readout as fcn is True
KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30],
KEY.READOUT_FCN_ACTIVATION: 'relu',
KEY.SELF_CONNECTION_TYPE: 'nequip',
KEY.INTERACTION_TYPE: 'nequip',
KEY._NORMALIZE_SPH: True,
KEY.CUEQUIVARIANCE_CONFIG: {},
}
# Basically, "If provided, it should be type of ..."
MODEL_CONFIG_CONDITION = {
KEY.NODE_FEATURE_MULTIPLICITY: int,
KEY.LMAX: int,
KEY.LMAX_EDGE: int,
KEY.LMAX_NODE: int,
KEY.IS_PARITY: bool,
KEY.RADIAL_BASIS: {
KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS,
},
KEY.CUTOFF_FUNCTION: {
KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION,
},
KEY.CUTOFF: float,
KEY.NUM_CONVOLUTION: int,
KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float)
or x
in [
'avg_num_neigh',
'sqrt_avg_num_neigh',
],
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list,
KEY.TRAIN_SHIFT_SCALE: bool,
KEY.TRAIN_DENOMINTAOR: bool,
KEY.USE_BIAS_IN_LINEAR: bool,
KEY.USE_MODAL_NODE_EMBEDDING: bool,
KEY.USE_MODAL_SELF_INTER_INTRO: bool,
KEY.USE_MODAL_SELF_INTER_OUTRO: bool,
KEY.USE_MODAL_OUTPUT_BLOCK: bool,
KEY.READOUT_AS_FCN: bool,
KEY.READOUT_FCN_HIDDEN_NEURONS: list,
KEY.READOUT_FCN_ACTIVATION: str,
KEY.ACTIVATION_RADIAL: str,
KEY.SELF_CONNECTION_TYPE: lambda x: (
x in IMPLEMENTED_SELF_CONNECTION_TYPE
or (
isinstance(x, list)
and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x)
)
),
KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE,
KEY._NORMALIZE_SPH: bool,
KEY.CUEQUIVARIANCE_CONFIG: dict,
}
def model_defaults(config):
    """Return default model hyperparameters, adjusted for ``config``.

    If ``config`` does not specify READOUT_AS_FCN, the default is written
    into ``config`` (intentional side effect). When the readout is not an
    FCN, the FCN-only keys are dropped from the returned defaults.
    """
    # Copy the module-level constant: the previous code aliased it and then
    # popped keys from it, permanently mutating the shared defaults for every
    # subsequent call in the same process.
    defaults = dict(DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG)
    if KEY.READOUT_AS_FCN not in config:
        config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN]
    if config[KEY.READOUT_AS_FCN] is False:
        defaults.pop(KEY.READOUT_FCN_ACTIVATION, None)
        defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None)
    return defaults
DEFAULT_DATA_CONFIG = {
KEY.DTYPE: 'single',
KEY.DATA_FORMAT: 'ase',
KEY.DATA_FORMAT_ARGS: {},
KEY.SAVE_DATASET: False,
KEY.SAVE_BY_LABEL: False,
KEY.SAVE_BY_TRAIN_VALID: False,
KEY.RATIO: 0.0,
KEY.BATCH_SIZE: 6,
KEY.PREPROCESS_NUM_CORES: 1,
KEY.COMPUTE_STATISTICS: True,
KEY.DATASET_TYPE: 'graph',
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: False,
KEY.USE_MODAL_WISE_SHIFT: False,
KEY.USE_MODAL_WISE_SCALE: False,
KEY.SHIFT: 'per_atom_energy_mean',
KEY.SCALE: 'force_rms',
# KEY.DATA_SHUFFLE: True,
# KEY.DATA_WEIGHT: False,
# KEY.DATA_MODALITY: False,
}
DATA_CONFIG_CONDITION = {
KEY.DTYPE: str,
KEY.DATA_FORMAT: str,
KEY.DATA_FORMAT_ARGS: dict,
KEY.SAVE_DATASET: str,
KEY.SAVE_BY_LABEL: bool,
KEY.SAVE_BY_TRAIN_VALID: bool,
KEY.RATIO: float,
KEY.BATCH_SIZE: int,
KEY.PREPROCESS_NUM_CORES: int,
KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'],
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool,
KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT,
KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE,
KEY.USE_MODAL_WISE_SHIFT: bool,
KEY.USE_MODAL_WISE_SCALE: bool,
# KEY.DATA_SHUFFLE: bool,
KEY.COMPUTE_STATISTICS: bool,
# KEY.DATA_WEIGHT: bool,
# KEY.DATA_MODALITY: bool,
}
def data_defaults(config):
    """Return default data-processing options, adjusted for ``config``.

    If an explicit validation set is given, the train/valid split ratio is
    dropped from the returned defaults (the split is not needed).
    """
    # Copy the module-level constant: popping from the aliased dict would
    # mutate DEFAULT_DATA_CONFIG itself and corrupt later calls.
    defaults = dict(DEFAULT_DATA_CONFIG)
    if KEY.LOAD_VALIDSET in config:
        defaults.pop(KEY.RATIO, None)
    return defaults
DEFAULT_TRAINING_CONFIG = {
KEY.RANDOM_SEED: 1,
KEY.EPOCH: 300,
KEY.LOSS: 'mse',
KEY.LOSS_PARAM: {},
KEY.OPTIMIZER: 'adam',
KEY.OPTIM_PARAM: {},
KEY.SCHEDULER: 'exponentiallr',
KEY.SCHEDULER_PARAM: {},
KEY.FORCE_WEIGHT: 0.1,
KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default
KEY.PER_EPOCH: 5,
# KEY.USE_TESTSET: False,
KEY.CONTINUE: {
KEY.CHECKPOINT: False,
KEY.RESET_OPTIMIZER: False,
KEY.RESET_SCHEDULER: False,
KEY.RESET_EPOCH: False,
KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True,
KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True,
},
# KEY.DEFAULT_MODAL: 'common',
KEY.CSV_LOG: 'log.csv',
KEY.NUM_WORKERS: 0,
KEY.IS_TRAIN_STRESS: True,
KEY.TRAIN_SHUFFLE: True,
KEY.ERROR_RECORD: [
['Energy', 'RMSE'],
['Force', 'RMSE'],
['Stress', 'RMSE'],
['TotalLoss', 'None'],
],
KEY.BEST_METRIC: 'TotalLoss',
KEY.USE_WEIGHT: False,
KEY.USE_MODALITY: False,
}
TRAINING_CONFIG_CONDITION = {
KEY.RANDOM_SEED: int,
KEY.EPOCH: int,
KEY.FORCE_WEIGHT: float,
KEY.STRESS_WEIGHT: float,
KEY.USE_TESTSET: None, # Not used
KEY.NUM_WORKERS: int,
KEY.PER_EPOCH: int,
KEY.CONTINUE: {
KEY.CHECKPOINT: str,
KEY.RESET_OPTIMIZER: bool,
KEY.RESET_SCHEDULER: bool,
KEY.RESET_EPOCH: bool,
KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool,
KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool,
},
KEY.DEFAULT_MODAL: str,
KEY.IS_TRAIN_STRESS: bool,
KEY.TRAIN_SHUFFLE: bool,
KEY.ERROR_RECORD: error_record_condition,
KEY.BEST_METRIC: str,
KEY.CSV_LOG: str,
KEY.USE_MODALITY: bool,
KEY.USE_WEIGHT: bool,
}
def train_defaults(config):
    """Return default training options, adjusted for ``config``.

    If ``config`` does not specify IS_TRAIN_STRESS, the default is written
    into ``config`` (intentional side effect). When stress training is
    disabled, the stress loss weight is dropped from the returned defaults.
    """
    # Copy the module-level constant: popping from the aliased dict would
    # mutate DEFAULT_TRAINING_CONFIG itself and corrupt later calls.
    defaults = dict(DEFAULT_TRAINING_CONFIG)
    if KEY.IS_TRAIN_STRESS not in config:
        config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS]
    if not config[KEY.IS_TRAIN_STRESS]:
        defaults.pop(KEY.STRESS_WEIGHT, None)
    return defaults
"""
How to add new feature?
1. Add new key to this file.
2. Add new key to _const.py
2.1. if the type of input is consistent,
write adequate condition and default to _const.py.
2.2. if the type of input is not consistent,
you must add your own input validation code to
parse_input.py
"""
from typing import Final
# see
# https://github.com/pytorch/pytorch/issues/52312
# for FYI
# ~~ keys ~~ #
# PyG : primitive key of torch_geometric.data.Data type
# ==================================================#
# ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ #
# ==================================================#
# some raw properties of graph
# Raw per-graph properties (shapes in parentheses; N = number of atoms).
ATOMIC_NUMBERS: Final[str] = 'atomic_numbers'  # (N)
POS: Final[str] = 'pos'  # (N, 3) PyG
CELL: Final[str] = 'cell_lattice_vectors'  # (3, 3)
CELL_SHIFT: Final[str] = 'pbc_shift'  # (N, 3)
CELL_VOLUME: Final[str] = 'cell_volume'
EDGE_VEC: Final[str] = 'edge_vec'  # (N_edge, 3)
EDGE_LENGTH: Final[str] = 'edge_length'  # (N_edge, 1)
# some primary data of graph
EDGE_IDX: Final[str] = 'edge_index'  # (2, N_edge) PyG
ATOM_TYPE: Final[str] = 'atom_type'  # (N) one-hot index of nodes
NODE_FEATURE: Final[str] = 'x'  # (N, ?) PyG
NODE_FEATURE_GHOST: Final[str] = 'x_ghost'  # NOTE(review): presumably ghost-atom features for LAMMPS parallel — confirm
NODE_ATTR: Final[str] = 'node_attr'  # (N, N_species) from one_hot
MODAL_ATTR: Final[str] = (
    'modal_attr'  # (1, N_modalities) for handling multi-modal
)
MODAL_TYPE: Final[str] = 'modal_type'  # (1) one-hot index of modal
EDGE_ATTR: Final[str] = 'edge_attr'  # (from spherical harmonics)
EDGE_EMBEDDING: Final[str] = 'edge_embedding'  # (from edge embedding)
# inputs of loss function (reference labels)
ENERGY: Final[str] = 'total_energy'  # (1)
FORCE: Final[str] = 'force_of_atoms'  # (N, 3)
STRESS: Final[str] = 'stress'  # (6)
# This is for training, per atom scale.
SCALED_ENERGY: Final[str] = 'scaled_total_energy'
# general outputs of models
SCALED_ATOMIC_ENERGY: Final[str] = 'scaled_atomic_energy'
ATOMIC_ENERGY: Final[str] = 'atomic_energy'
PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy'
PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy'
PER_ATOM_ENERGY: Final[str] = 'per_atom_energy'
PRED_FORCE: Final[str] = 'inferred_force'
SCALED_FORCE: Final[str] = 'scaled_force'
PRED_STRESS: Final[str] = 'inferred_stress'
SCALED_STRESS: Final[str] = 'scaled_stress'
# very general data property for AtomGraphData
NUM_ATOMS: Final[str] = 'num_atoms'  # int
NUM_GHOSTS: Final[str] = 'num_ghosts'
NLOCAL: Final[str] = 'nlocal'  # only for lammps parallel, must be on cpu
USER_LABEL: Final[str] = 'user_label'
DATA_WEIGHT: Final[str] = 'data_weight'  # weight for given data
DATA_MODALITY: Final[str] = (
    'data_modality'  # modality of given data, e.g. PBE and SCAN
)
BATCH: Final[str] = 'batch'
TAG: Final[str] = 'tag'  # replaces USER_LABEL
# etc
SELF_CONNECTION_TEMP: Final[str] = 'self_cont_tmp'
BATCH_SIZE: Final[str] = 'batch_size'
INFO: Final[str] = 'data_info'
# something special: placeholder label for unlabeled data
LABEL_NONE: Final[str] = 'No_label'
# ==================================================#
# ~~~~~~ KEY for train/data configuration ~~~~~~~~ #
# ==================================================#
# --- dataset construction / preprocessing keys ---
PREPROCESS_NUM_CORES = 'preprocess_num_cores'
SAVE_DATASET = 'save_dataset_path'
SAVE_BY_LABEL = 'save_by_label'
SAVE_BY_TRAIN_VALID = 'save_by_train_valid'
DATA_FORMAT = 'data_format'
DATA_FORMAT_ARGS = 'data_format_args'
STRUCTURE_LIST = 'structure_list'
LOAD_DATASET = 'load_dataset_path'  # not used in v2
LOAD_TRAINSET = 'load_trainset_path'
LOAD_VALIDSET = 'load_validset_path'
LOAD_TESTSET = 'load_testset_path'
FORMAT_OUTPUTS = 'format_outputs_for_ase'
COMPUTE_STATISTICS = 'compute_statistics'
DATASET_TYPE = 'dataset_type'
RANDOM_SEED = 'random_seed'
RATIO = 'data_divide_ratio'
USE_TESTSET = 'use_testset'
# --- optimization hyperparameter keys ---
EPOCH = 'epoch'
LOSS = 'loss'
LOSS_PARAM = 'loss_param'
OPTIMIZER = 'optimizer'
OPTIM_PARAM = 'optim_param'
SCHEDULER = 'scheduler'
SCHEDULER_PARAM = 'scheduler_param'
FORCE_WEIGHT = 'force_loss_weight'
STRESS_WEIGHT = 'stress_loss_weight'
DEVICE = 'device'
DTYPE = 'dtype'
TRAIN_SHUFFLE = 'train_shuffle'
IS_TRAIN_STRESS = 'is_train_stress'
# --- continue-from-checkpoint keys ---
CONTINUE = 'continue'
CHECKPOINT = 'checkpoint'
RESET_OPTIMIZER = 'reset_optimizer'
RESET_SCHEDULER = 'reset_scheduler'
RESET_EPOCH = 'reset_epoch'
USE_STATISTIC_VALUES_OF_CHECKPOINT = 'use_statistic_values_of_checkpoint'
USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = (
    'use_statistic_values_for_cp_modal_only'
)
# --- logging / metric keys ---
CSV_LOG = 'csv_log'
ERROR_RECORD = 'error_record'
BEST_METRIC = 'best_metric'
NUM_WORKERS = 'num_workers'  # marked non-functional in the original ("not work")
# --- distributed-training (DDP) keys ---
RANK = 'rank'
LOCAL_RANK = 'local_rank'
WORLD_SIZE = 'world_size'
IS_DDP = 'is_ddp'
DDP_BACKEND = 'ddp_backend'
# --- misc training keys ---
PER_EPOCH = 'per_epoch'
USE_WEIGHT = 'use_weight'
USE_MODALITY = 'use_modality'
DEFAULT_MODAL = 'default_modal'
# ==================================================#
# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ #
# ==================================================#
# ~~ global model configuration ~~ #
# note that these names are directly used for input.yaml for user input
# Keys prefixed with '_' are internal (derived, not user-facing yaml keys).
MODEL_TYPE = '_model_type'
CUTOFF = 'cutoff'
CHEMICAL_SPECIES = 'chemical_species'
MODAL_LIST = 'modal_list'
CHEMICAL_SPECIES_BY_ATOMIC_NUMBER = '_chemical_species_by_atomic_number'
NUM_SPECIES = '_number_of_species'
NUM_MODALITIES = '_number_of_modalities'
TYPE_MAP = '_type_map'
MODAL_MAP = '_modal_map'
# ~~ E3 equivariant model build configuration keys ~~ #
# see model_build default_config for type
IRREPS_MANUAL = 'irreps_manual'
NODE_FEATURE_MULTIPLICITY = 'channel'
RADIAL_BASIS = 'radial_basis'
BESSEL_BASIS_NUM = 'bessel_basis_num'
CUTOFF_FUNCTION = 'cutoff_function'
POLY_CUT_P = 'poly_cut_p_value'
LMAX = 'lmax'
LMAX_EDGE = 'lmax_edge'
LMAX_NODE = 'lmax_node'
IS_PARITY = 'is_parity'
CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS = 'weight_nn_hidden_neurons'
NUM_CONVOLUTION = 'num_convolution_layer'
# NOTE: 'SCARLAR' is a typo of 'SCALAR'; the name is kept as-is because
# external code may reference it.
ACTIVATION_SCARLAR = 'act_scalar'
ACTIVATION_GATE = 'act_gate'
ACTIVATION_RADIAL = 'act_radial'
SELF_CONNECTION_TYPE = 'self_connection_type'
RADIAL_BASIS_NAME = 'radial_basis_name'
CUTOFF_FUNCTION_NAME = 'cutoff_function_name'
USE_BIAS_IN_LINEAR = 'use_bias_in_linear'
USE_MODAL_NODE_EMBEDDING = 'use_modal_node_embedding'
USE_MODAL_SELF_INTER_INTRO = 'use_modal_self_inter_intro'
USE_MODAL_SELF_INTER_OUTRO = 'use_modal_self_inter_outro'
USE_MODAL_OUTPUT_BLOCK = 'use_modal_output_block'
READOUT_AS_FCN = 'readout_as_fcn'
READOUT_FCN_HIDDEN_NEURONS = 'readout_fcn_hidden_neurons'
READOUT_FCN_ACTIVATION = 'readout_fcn_activation'
AVG_NUM_NEIGH = 'avg_num_neigh'
CONV_DENOMINATOR = 'conv_denominator'
SHIFT = 'shift'
SCALE = 'scale'
USE_SPECIES_WISE_SHIFT_SCALE = 'use_species_wise_shift_scale'
USE_MODAL_WISE_SHIFT = 'use_modal_wise_shift'
USE_MODAL_WISE_SCALE = 'use_modal_wise_scale'
TRAIN_SHIFT_SCALE = 'train_shift_scale'
# NOTE: 'DENOMINTAOR' is a typo of 'DENOMINATOR'; kept for backward
# compatibility (the string value is spelled correctly).
TRAIN_DENOMINTAOR = 'train_denominator'
INTERACTION_TYPE = 'interaction_type'
TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh'  # deprecated
CUEQUIVARIANCE_CONFIG = 'cuequivariance_config'
_NORMALIZE_SPH = '_normalize_sph'
OPTIMIZE_BY_REDUCE = 'optimize_by_reduce'
from typing import Optional
import torch
import torch_geometric.data
import sevenn._keys as KEY
import sevenn.util
class AtomGraphData(torch_geometric.data.Data):
    """Atom-graph container built on :class:`torch_geometric.data.Data`.

    Args:
        x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes,
            atomic_numbers]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in coordinate
            format with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
        **kwargs (optional): Additional attributes stored on the graph.

    Reference label units (VASP raw):
        y_energy: scalar, eV
        y_force: [num_nodes, 3], eV/A
        y_stress: [6] as [xx, yy, zz, xy, yz, zx], eV/A^3

    x, y_force, and pos should be aligned with each other (same atom order).
    """

    def __init__(
        self,
        x: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
        pos: Optional[torch.Tensor] = None,
        edge_attr: Optional[torch.Tensor] = None,
        **kwargs
    ):
        super().__init__(x, edge_index, edge_attr, pos=pos)
        # Mirror x under NODE_ATTR as well.
        # NOTE(review): original carried a '?' comment here — intent unclear.
        self[KEY.NODE_ATTR] = x
        for k, v in kwargs.items():
            self[k] = v

    def to_numpy_dict(self):
        """Return a plain dict of this graph's items with every tensor
        detached, moved to CPU, and converted to numpy.

        Non-tensor values are passed through unchanged.
        """
        # isinstance (instead of `type(v) is torch.Tensor`) also converts
        # Tensor subclasses such as nn.Parameter.
        return {
            k: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v
            for k, v in self.items()
        }

    def fit_dimension(self):
        """Squeeze singleton dimensions of every tensor attribute in place.

        For a single-atom graph, a plain squeeze would also drop the atom
        axis of per-atom quantities (positions, forces, ...); those keys are
        re-expanded to keep a leading dimension of size 1.

        Returns:
            self, to allow chaining.
        """
        per_atom_keys = [
            KEY.ATOMIC_NUMBERS,
            KEY.ATOMIC_ENERGY,
            KEY.POS,
            KEY.FORCE,
            KEY.PRED_FORCE,
        ]
        natoms = self.num_atoms.item()
        for k, v in self.items():
            if not isinstance(v, torch.Tensor):
                continue
            if natoms == 1 and k in per_atom_keys:
                # keep per-atom arrays 2D: (1, ...)
                self[k] = v.squeeze().unsqueeze(0)
            else:
                self[k] = v.squeeze()
        return self

    @staticmethod
    def from_numpy_dict(dct):
        """Build an AtomGraphData from a dict of numpy arrays / raw values.

        Bug fix: the original wrote converted tensors back into ``dct``,
        mutating the caller's dict in place; conversion now happens in a
        fresh dict so the input is left untouched.
        """
        converted = {}
        for k, v in dct.items():
            if k == KEY.CELL_SHIFT:
                converted[k] = torch.Tensor(v)  # this one is special-cased
            else:
                converted[k] = sevenn.util.dtype_correct(v)
        return AtomGraphData(**converted)
# Generated-style Ninja build file for the `pair_d3` CUDA/PyTorch extension
# (torch.utils.cpp_extension layout).  Paths are machine-specific.
ninja_required_version = 1.3
cxx = c++
nvcc = /usr/local/cuda/bin/nvcc
# Host-compiler flags: Torch/pybind11 include paths, old C++ ABI, C++17.
cflags = -DTORCH_EXTENSION_NAME=pair_d3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/TH -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17
post_cflags =
# Device-compiler flags: SASS for sm_61..sm_90; -fmad=false disables
# fused multiply-add for bit-reproducible arithmetic.
cuda_cflags = -DTORCH_EXTENSION_NAME=pair_d3 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/TH -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/mazhaojia/pkg/miniconda3/envs/7net-cueq/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -O3 --expt-relaxed-constexpr -fmad=false -std=c++17
cuda_post_cflags =
cuda_dlink_post_cflags =
# Link against Torch's C++/CUDA runtime libraries.
ldflags = -shared -L/home/mazhaojia/pkg/miniconda3/envs/7net-cueq/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart
# Host compile with GCC-style header dependency tracking.
rule compile
  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
  depfile = $out.d
  deps = gcc
# Device compile; nvcc emits the depfile alongside compilation.
rule cuda_compile
  depfile = $out.d
  deps = gcc
  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
rule link
  command = $cxx $in $ldflags -o $out
build pair_d3_for_ase.cuda.o: cuda_compile /home/mazhaojia/mace-project/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/pair_d3_for_ase.cu
build pair_d3.so: link pair_d3_for_ase.cuda.o
default pair_d3.so
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment