nivren / ICT-CSP / Commits / ca86f720

Unverified commit ca86f720, authored Aug 24, 2025 by zcxzcx1, committed via GitHub on Aug 24, 2025.
Commit message: Add files via upload
Parent: b75ed73c
Changes: 81 files in total; showing 20 changed files with 4648 additions and 0 deletions (+4648, -0).
Changed files:

  mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py           +301  -0
  mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py                           +148  -0
  mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py                      +119  -0
  mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py                        +227  -0
  mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py              +273  -0
  mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py               +481  -0
  mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py                 +182  -0
  mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py                            +139  -0
  mace-bench/3rdparty/SevenNet/sevenn/train/__init__.py                           +0    -0
  mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/__init__.cpython-310.pyc  +0    -0
  mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataload.cpython-310.pyc  +0    -0
  mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataset.cpython-310.pyc   +0    -0
  mace-bench/3rdparty/SevenNet/sevenn/train/atoms_dataset.py                      +314  -0
  mace-bench/3rdparty/SevenNet/sevenn/train/collate.py                            +41   -0
  mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py                           +609  -0
  mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py                            +496  -0
  mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py                      +707  -0
  mace-bench/3rdparty/SevenNet/sevenn/train/loss.py                               +223  -0
  mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py                      +365  -0
  mace-bench/3rdparty/SevenNet/sevenn/train/optim.py                              +23   -0
mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py (new file, mode 100644)

import math
from typing import List

import torch
import torch.nn as nn
from e3nn.o3 import Irreps, Linear

import sevenn._keys as KEY
from sevenn.model_build import build_E3_equivariant_model

# Maps each "use modal ..." config flag to the module whose linear layer
# carries the modal-dependent parameters.
modal_module_dict = {
    KEY.USE_MODAL_NODE_EMBEDDING: 'onehot_to_feature_x',
    KEY.USE_MODAL_SELF_INTER_INTRO: 'self_interaction_1',
    KEY.USE_MODAL_SELF_INTER_OUTRO: 'self_interaction_2',
    KEY.USE_MODAL_OUTPUT_BLOCK: 'reduce_input_to_hidden',
}


def _get_scalar_index(irreps: Irreps):
    scalar_indices = []
    for idx, (_, (l, p)) in enumerate(irreps):  # noqa
        if l == 0 and p == 1:
            # get index of parameter for scalar (0e), which is used for modality
            scalar_indices.append(idx)
    return scalar_indices


def _reshape_weight_of_linear(
    irreps_in: Irreps, irreps_out: Irreps, weight: torch.Tensor
) -> List[torch.Tensor]:
    linear = Linear(irreps_in, irreps_out)
    linear.weight = nn.Parameter(weight)
    return list(linear.weight_views())


def _erase_linear_modal_params(
    model_state_dct: dict,
    erase_modal_indices: List[int],
    key: str,
    irreps_in: Irreps,
    irreps_out: Irreps,
):
    orig_input_dim = irreps_in.count('0e')
    new_input_dim = orig_input_dim - len(erase_modal_indices)
    orig_weight = model_state_dct[key + '.linear.weight']
    scalar_idx = _get_scalar_index(irreps_in)
    linear_weight_list = _reshape_weight_of_linear(
        irreps_in, irreps_out, orig_weight
    )
    new_weight_list = []
    for idx, l_p_weight in enumerate(linear_weight_list[:-1]):
        new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze()
        if idx in scalar_idx:
            new_weight = new_weight * math.sqrt(new_input_dim / orig_input_dim)
        new_weight_list.append(new_weight)
    """
    Following works for normalization = `path`, which is not used in SEVENNet
    for l_p_weight in linear_weight_list[:-1]:
        new_weight_list.append(torch.reshape(l_p_weight, (1, -1)).squeeze())
    """
    flattened_weight = torch.cat(new_weight_list)
    return flattened_weight


def _get_modal_weight_as_bias(
    model_state_dct: dict,
    key: str,
    ref_index: int,
    irreps_in: Irreps,
    irreps_out: Irreps,
):
    assert ref_index != -1
    input_dim = irreps_in.count('0e')
    output_dim = irreps_out.count('0e')
    orig_weight = model_state_dct[key + '.linear.weight']
    orig_bias = model_state_dct[key + '.linear.bias']
    if len(orig_bias) == 0:
        orig_bias = torch.zeros(output_dim, dtype=orig_weight.dtype)
    modal_weight = _reshape_weight_of_linear(
        irreps_in, irreps_out, orig_weight
    )[-1]
    new_bias = orig_bias + modal_weight[ref_index] / math.sqrt(input_dim)
    return new_bias


def _append_modal_weight(
    model_state_dct: dict,  # state dict to be targeted
    key: str,  # linear weight module name
    irreps_in: Irreps,  # irreps_in before modality append
    irreps_out: Irreps,
    append_number: int,
):
    # This works for normalization = `element`, default in SEVENNet.
    # (normalization = `path` is currently deprecated in SEVENNet.)
    input_dim = irreps_in.count('0e')
    output_dim = irreps_out.count('0e')
    new_input_dim = input_dim + append_number
    orig_weight = model_state_dct[key + '.linear.weight']
    scalar_idx = _get_scalar_index(irreps_in)
    linear_weight_list = _reshape_weight_of_linear(
        irreps_in, irreps_out, orig_weight
    )
    new_weight_list = []
    # TODO: combine following as function with _erase_linear_modal_params
    for idx, l_p_weight in enumerate(linear_weight_list):
        new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze()
        if idx in scalar_idx:
            new_weight = new_weight * math.sqrt(new_input_dim / input_dim)
        new_weight_list.append(new_weight)
    flattened_weight_list = []
    for l_p_weight in new_weight_list:
        flattened_weight_list.append(
            torch.reshape(l_p_weight, (1, -1)).squeeze()
        )
    flattened_weight = torch.cat(flattened_weight_list)
    append_weight = torch.cat([
        flattened_weight,
        torch.zeros(append_number * output_dim, dtype=flattened_weight.dtype),
    ])  # zeros: starting from common model
    return append_weight


def get_single_modal_model_dct(
    model_state_dct: dict,
    config: dict,
    ref_modal: str,
    from_processing_cp: bool = False,
    is_deploy: bool = False,
):
    """
    Convert a multimodal model state dictionary to a single modal model.
    The modal is selected by `ref_modal`.

    `model_state_dct`: model state dictionary from a multimodal checkpoint file
    `config`: dictionary containing the configuration of the checkpoint model
    `ref_modal`: the modal to convert to
    `from_processing_cp`: if True, use the modal_map of the checkpoint file
    `is_deploy`: if True, the model is built with single-modal shift and scale
    """
    if not from_processing_cp and not config[KEY.USE_MODALITY]:
        # model is already single modal
        return model_state_dct

    config[KEY.USE_BIAS_IN_LINEAR] = True
    config['_deploy'] = is_deploy
    model = build_E3_equivariant_model(config)
    del config['_deploy']

    key_add = '_cp' if from_processing_cp else ''
    modal_type_dict = config[KEY.MODAL_MAP + key_add]
    erase_modal_indices = range(len(modal_type_dict.keys()))  # starts with 0
    if ref_modal != 'common':
        try:
            ref_modal_index = modal_type_dict[ref_modal]
        except KeyError:
            raise KeyError(
                f'{ref_modal} not in modal type. Use one of'
                f' {modal_type_dict.keys()}.'
            )

    for module_key in model._modules.keys():
        for (
            use_modal_module_key,
            modal_module_name,
        ) in modal_module_dict.items():
            irreps_out = Irreps(model.get_irreps_in(module_key, 'irreps_out'))
            # TODO: directly using "irreps_in" might not be compatible
            #       when changing `nn/linear.py`
            output_dim = irreps_out.count('0e')
            if config[use_modal_module_key] and modal_module_name in module_key:
                # this module is used for giving modality
                irreps_in = Irreps(
                    model.get_irreps_in(module_key, 'irreps_in')
                )
                new_bias = (
                    torch.zeros(output_dim)
                    if ref_modal == 'common'
                    else _get_modal_weight_as_bias(
                        model_state_dct,
                        module_key,
                        ref_modal_index,
                        irreps_in,  # type: ignore
                        irreps_out,  # type: ignore
                    )
                )
                erased_modal_weight = _erase_linear_modal_params(
                    model_state_dct,
                    erase_modal_indices,
                    module_key,
                    irreps_in,  # type: ignore
                    irreps_out,  # type: ignore
                )
                model_state_dct[module_key + '.linear.weight'] = (
                    erased_modal_weight
                )
                model_state_dct[module_key + '.linear.bias'] = new_bias
            elif modal_module_name in module_key:
                model_state_dct[module_key + '.linear.bias'] = torch.zeros(
                    output_dim,
                    dtype=model_state_dct[module_key + '.linear.weight'].dtype,
                )

    final_block_key = 'reduce_hidden_to_energy'
    model_state_dct[final_block_key + '.linear.bias'] = torch.tensor(
        [0], dtype=model_state_dct[final_block_key + '.linear.weight'].dtype
    )

    if config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE]:
        rescaler_names = []
        if config[KEY.USE_MODAL_WISE_SHIFT]:
            rescaler_names.append('shift')
        if config[KEY.USE_MODAL_WISE_SCALE]:
            rescaler_names.append('scale')
        config[KEY.USE_MODAL_WISE_SHIFT] = False
        config[KEY.USE_MODAL_WISE_SCALE] = False
        for rescaler_name in rescaler_names:
            rescaler_key = 'rescale_atomic_energy.' + rescaler_name
            rescaler = model_state_dct[rescaler_key][ref_modal_index]
            model_state_dct.update({rescaler_key: rescaler})
            config.update({rescaler_name: rescaler})

    config[KEY.USE_MODALITY] = False
    return model_state_dct


def append_modality_to_model_dct(
    model_state_dct: dict,
    config: dict,
    orig_num_modal: int,
    append_modal_length: int,
):
    """
    Append modal-wise parameters to the original linear layers.
    This enables appending new modalities to a single/multi modal
    model checkpoint.

    `model_state_dct`: model state dictionary from a multimodal checkpoint file
    `config`: dictionary containing the configuration of the checkpoint model,
        with the modality appended
    `orig_num_modal`: number of modalities used in the original checkpoint
    `append_modal_length`: number of modalities to be appended in the new
        checkpoint
    """
    config_num_modal = config[KEY.NUM_MODALITIES]
    config.update({KEY.NUM_MODALITIES: orig_num_modal, KEY.USE_MODALITY: True})
    model = build_E3_equivariant_model(config)
    for module_key in model._modules.keys():
        for (
            use_modal_module_key,
            modal_module_name,
        ) in modal_module_dict.items():
            if config[use_modal_module_key] and modal_module_name in module_key:
                # this module is used for giving modality
                irreps_in = model.get_irreps_in(module_key, 'irreps_in')
                # TODO: directly using "irreps_in" might not be compatible
                #       when changing `nn/linear.py`
                irreps_out = model.get_irreps_in(module_key, 'irreps_out')
                irreps_in, irreps_out = Irreps(irreps_in), Irreps(irreps_out)
                append_weight = _append_modal_weight(
                    model_state_dct,
                    module_key,
                    irreps_in,  # type: ignore
                    irreps_out,  # type: ignore
                    append_modal_length,
                )
                model_state_dct[module_key + '.linear.weight'] = append_weight
    config[KEY.NUM_MODALITIES] = config_num_modal
    return model_state_dct
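The helpers above operate directly on a checkpoint's state dict. A minimal usage sketch follows (not part of the commit); the checkpoint key names 'model_state_dict' and 'config' and the modal name 'mpa' are assumptions for illustration, mirroring the layout that processing_continue.py reads further down in this diff.

# Illustrative sketch only: collapse a multimodal SevenNet checkpoint onto a
# single modal. Checkpoint key names and the modal name 'mpa' are assumptions.
import torch

from sevenn.scripts.convert_model_modality import get_single_modal_model_dct

cp = torch.load('checkpoint_best.pth', map_location='cpu', weights_only=False)
state_dct, config = cp['model_state_dict'], cp['config']  # assumed layout

# Fold the 'mpa' modal weights into the linear biases and erase the modal
# columns; config is updated in place (use_modality becomes False).
state_dct = get_single_modal_model_dct(state_dct, config, ref_modal='mpa')
torch.save({'model_state_dict': state_dct, 'config': config}, 'single_modal.pth')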
mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py (new file, mode 100644)

import os
from datetime import datetime
from typing import Optional

import e3nn.util.jit
import torch
import torch.nn
from ase.data import chemical_symbols

import sevenn._keys as KEY
from sevenn import __version__
from sevenn.model_build import build_E3_equivariant_model
from sevenn.util import load_checkpoint


def deploy(checkpoint, fname='deployed_serial.pt', modal: Optional[str] = None):
    """
    This method is messy to avoid changes in pair_e3gnn.cpp while
    refactoring the python part.
    If we change the behavior (and accordingly pair_e3gnn.cpp),
    we have to recompile LAMMPS (which I always want to procrastinate).
    """
    from sevenn.nn.edge_embedding import EdgePreprocess
    from sevenn.nn.force_output import ForceStressOutput

    cp = load_checkpoint(checkpoint)
    model, config = cp.build_model('e3nn'), cp.config

    model.prepand_module('edge_preprocess', EdgePreprocess(True))
    grad_module = ForceStressOutput()
    model.replace_module('force_output', grad_module)
    new_grad_key = grad_module.get_grad_key()
    model.key_grad = new_grad_key
    if hasattr(model, 'eval_type_map'):
        setattr(model, 'eval_type_map', False)

    if modal:
        model.prepare_modal_deploy(modal)
    elif model.modal_map is not None and len(model.modal_map) >= 1:
        raise ValueError(
            f'Modal is not given. It has: {list(model.modal_map.keys())}'
        )

    model.set_is_batch_data(False)
    model.eval()

    model = e3nn.util.jit.script(model)
    model = torch.jit.freeze(model)

    # gather some config needed for MD
    md_configs = {}
    type_map = config[KEY.TYPE_MAP]
    chem_list = ''
    for Z in type_map.keys():
        chem_list += chemical_symbols[Z] + ' '
    chem_list = chem_list.strip()
    md_configs.update({'chemical_symbols_to_index': chem_list})
    md_configs.update({'cutoff': str(config[KEY.CUTOFF])})
    md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])})
    md_configs.update(
        {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')}
    )
    md_configs.update({'version': __version__})
    md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')})
    md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')})

    if not fname.endswith('.pt'):
        fname += '.pt'
    torch.jit.save(model, fname, _extra_files=md_configs)


# TODO: build model only once
def deploy_parallel(
    checkpoint, fname='deployed_parallel', modal: Optional[str] = None
):
    # Additional layer for ghost atom (and copy parameters from original)
    GHOST_LAYERS_KEYS = ['onehot_to_feature_x', '0_self_interaction_1']

    cp = load_checkpoint(checkpoint)
    model, config = cp.build_model('e3nn'), cp.config
    config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': False}
    model_state_dct = model.state_dict()

    model_list = build_E3_equivariant_model(config, parallel=True)

    dct_temp = {}
    copy_counter = {gk: 0 for gk in GHOST_LAYERS_KEYS}
    for ghost_layer_key in GHOST_LAYERS_KEYS:
        for key, val in model_state_dct.items():
            if not key.startswith(ghost_layer_key):
                continue
            dct_temp.update({f'ghost_{key}': val})
            copy_counter[ghost_layer_key] += 1
    # Ensure reference weights are copied from state dict
    assert all(x > 0 for x in copy_counter.values())

    model_state_dct.update(dct_temp)
    for model_part in model_list:
        missing, _ = model_part.load_state_dict(model_state_dct, strict=False)
        if hasattr(model_part, 'eval_type_map'):
            setattr(model_part, 'eval_type_map', False)
        # Ensure all values are inserted
        assert len(missing) == 0, missing

    if modal:
        model_list[0].prepare_modal_deploy(modal)
    elif model_list[0].modal_map is not None:
        raise ValueError(
            f'Modal is not given. It has: {list(model_list[0].modal_map.keys())}'
        )

    # prepare some extra information for MD
    md_configs = {}
    type_map = config[KEY.TYPE_MAP]
    chem_list = ''
    for Z in type_map.keys():
        chem_list += chemical_symbols[Z] + ' '
    chem_list = chem_list.strip()

    comm_size = max(
        [
            seg._modules[f'{t}_convolution']._comm_size  # type: ignore
            for t, seg in enumerate(model_list)
        ]
    )
    md_configs.update({'chemical_symbols_to_index': chem_list})
    md_configs.update({'cutoff': str(config[KEY.CUTOFF])})
    md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])})
    md_configs.update({'comm_size': str(comm_size)})
    md_configs.update(
        {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')}
    )
    md_configs.update({'version': __version__})
    md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')})
    md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')})

    os.makedirs(fname)
    for idx, model in enumerate(model_list):
        fname_full = f'{fname}/deployed_parallel_{idx}.pt'
        model.set_is_batch_data(False)
        model.eval()

        model = e3nn.util.jit.script(model)
        model = torch.jit.freeze(model)

        torch.jit.save(model, fname_full, _extra_files=md_configs)
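Both deploy paths attach the MD metadata through the `_extra_files` argument of torch.jit.save, so the LAMMPS pair style can recover cutoff, species mapping, and so on from the deployed file. A hedged usage sketch (file names are placeholders, not from the commit):

# Illustrative sketch only: export a trained checkpoint for LAMMPS.
from sevenn.scripts.deploy import deploy, deploy_parallel

# Serial potential: a single TorchScript file with MD metadata attached.
deploy('checkpoint_best.pth', fname='deployed_serial.pt')

# Parallel potential: one .pt file per model segment, written into a new
# directory 'deployed_parallel/' (os.makedirs fails if it already exists).
deploy_parallel('checkpoint_best.pth', fname='deployed_parallel')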
mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py (new file, mode 100644)

import os
from typing import List, Optional

from sevenn.logger import Logger
from sevenn.train.dataset import AtomGraphDataset
from sevenn.util import unique_filepath


def build_sevennet_graph_dataset(
    source: List[str],
    cutoff: float,
    num_cores: int,
    out: str,
    filename: str,
    metadata: Optional[dict] = None,
    **fmt_kwargs,
):
    from sevenn.train.graph_dataset import SevenNetGraphDataset

    log = Logger()
    if metadata is None:
        metadata = {}

    log.timer_start('graph_build')
    db = SevenNetGraphDataset(
        cutoff=cutoff,
        root=out,
        files=source,
        processed_name=filename,
        process_num_cores=num_cores,
        **fmt_kwargs,
    )
    log.timer_end('graph_build', 'graph build time')

    log.writeline(f'Graph saved: {db.processed_paths[0]}')
    log.bar()
    for k, v in metadata.items():
        log.format_k_v(k, v, write=True)
    log.bar()
    log.writeline('Distribution:')
    log.statistic_write(db.statistics)
    log.format_k_v('# atoms (node)', db.natoms, write=True)
    log.format_k_v('# structures (graph)', len(db), write=True)


def dataset_finalize(dataset, metadata, out):
    """
    Deprecated
    """
    natoms = dataset.get_natoms()
    species = dataset.get_species()
    metadata = {
        **metadata,
        'natoms': natoms,
        'species': species,
    }
    dataset.meta = metadata

    if os.path.isdir(out):
        out = os.path.join(out, 'graph_built.sevenn_data')
    elif not out.endswith('.sevenn_data'):
        out = out + '.sevenn_data'
    out = unique_filepath(out)

    log = Logger()
    log.writeline('The metadata of the dataset is...')
    for k, v in metadata.items():
        log.format_k_v(k, v, write=True)
    dataset.save(out)
    log.writeline(f'dataset is saved to {out}')
    return dataset


def build_script(
    source: List[str],
    cutoff: float,
    num_cores: int,
    out: str,
    metadata: Optional[dict] = None,
    **fmt_kwargs,
):
    """
    Deprecated
    """
    from sevenn.train.dataload import file_to_dataset, match_reader

    if metadata is None:
        metadata = {}
    log = Logger()
    dataset = AtomGraphDataset({}, cutoff)

    common_args = {
        'cutoff': cutoff,
        'cores': num_cores,
        'label': 'graph_build',
    }
    log.timer_start('graph_build')
    for path in source:
        if os.path.isdir(path):
            continue
        log.writeline(f'Read: {path}')
        basename = os.path.basename(path)
        if 'structure_list' in basename:
            fmt = 'structure_list'
        else:
            fmt = 'ase'
        reader, rmeta = match_reader(fmt, **fmt_kwargs)
        metadata.update(**rmeta)
        dataset.augment(
            file_to_dataset(
                file=path,
                reader=reader,
                **common_args,
            )
        )
    log.timer_end('graph_build', 'graph build time')

    dataset_finalize(dataset, metadata, out)
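A usage sketch for the non-deprecated entry point above; the source file name and cutoff value are placeholders, not taken from the commit:

# Illustrative sketch only: build a processed graph dataset from an
# ASE-readable trajectory.
from sevenn.scripts.graph_build import build_sevennet_graph_dataset

build_sevennet_graph_dataset(
    source=['train_set.extxyz'],   # placeholder path
    cutoff=5.0,                    # Angstrom; should match the model cutoff
    num_cores=4,
    out='./dataset_root',          # root directory of the processed dataset
    filename='train_graphs.pt',
    metadata={'source_note': 'example build'},
)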
mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py (new file, mode 100644)

import csv
import os
from typing import Iterable, List, Optional, Union

import numpy as np
from torch_geometric.loader import DataLoader
from tqdm import tqdm

import sevenn._keys as KEY
import sevenn.util as util
from sevenn.atom_graph_data import AtomGraphData
from sevenn.train.graph_dataset import SevenNetGraphDataset
from sevenn.train.modal_dataset import SevenNetMultiModalDataset


def write_inference_csv(output_list, out):
    for i, output in enumerate(output_list):
        output = output.fit_dimension()
        # 1 eV/Angstrom^3 = 1602.1766208 kbar
        output[KEY.STRESS] = output[KEY.STRESS] * 1602.1766208
        output[KEY.PRED_STRESS] = output[KEY.PRED_STRESS] * 1602.1766208
        output_list[i] = output.to_numpy_dict()

    per_graph_keys = [
        KEY.NUM_ATOMS,
        KEY.USER_LABEL,
        KEY.ENERGY,
        KEY.PRED_TOTAL_ENERGY,
        KEY.STRESS,
        KEY.PRED_STRESS,
    ]

    per_atom_keys = [
        KEY.ATOMIC_NUMBERS,
        KEY.ATOMIC_ENERGY,
        KEY.POS,
        KEY.FORCE,
        KEY.PRED_FORCE,
    ]

    def unfold_dct_val(dct, keys, suffix_list=None):
        res = {}
        if suffix_list is None:
            suffix_list = range(100)
        for k in keys:
            if k not in dct:
                res[k] = '-'
            elif isinstance(dct[k], np.ndarray) and dct[k].ndim != 0:
                res.update(
                    {f'{k}_{suffix_list[i]}': v for i, v in enumerate(dct[k])}
                )
            else:
                res[k] = dct[k]
        return res

    def per_atom_dct_list(dct, keys):
        sfx_list = ['x', 'y', 'z']
        res = []
        natoms = dct[KEY.NUM_ATOMS]
        extracted = {k: dct[k] for k in keys}
        for i in range(natoms):
            raw = {}
            raw.update({k: v[i] for k, v in extracted.items()})
            per_atom_dct = unfold_dct_val(raw, keys, suffix_list=sfx_list)
            res.append(per_atom_dct)
        return res

    try:
        with open(f'{out}/info.csv', 'w', newline='') as f:
            header = output_list[0][KEY.INFO].keys()
            writer = csv.DictWriter(f, fieldnames=header)
            writer.writeheader()
            for output in output_list:
                writer.writerow(output[KEY.INFO])
    except (KeyError, TypeError, AttributeError, csv.Error) as e:
        print(e)
        print('failed to write meta data, info.csv is not written')

    with open(f'{out}/per_graph.csv', 'w', newline='') as f:
        sfx_list = ['xx', 'yy', 'zz', 'xy', 'yz', 'zx']  # for stress
        writer = None
        for output in output_list:
            cell_dct = {KEY.CELL: output[KEY.CELL]}
            cell_dct = unfold_dct_val(cell_dct, [KEY.CELL], ['a', 'b', 'c'])
            data = {
                **unfold_dct_val(output, per_graph_keys, sfx_list),
                **cell_dct,
            }
            if writer is None:
                writer = csv.DictWriter(f, fieldnames=data.keys())
                writer.writeheader()
            writer.writerow(data)

    with open(f'{out}/per_atom.csv', 'w', newline='') as f:
        writer = None
        for i, output in enumerate(output_list):
            list_of_dct = per_atom_dct_list(output, per_atom_keys)
            for j, dct in enumerate(list_of_dct):
                idx_dct = {'stct_id': i, 'atom_id': j}
                data = {**idx_dct, **dct}
                if writer is None:
                    writer = csv.DictWriter(f, fieldnames=data.keys())
                    writer.writeheader()
                writer.writerow(data)


def _patch_data_info(
    graph_list: Iterable[AtomGraphData], full_file_list: List[str]
) -> None:
    keys = set()
    for graph, path in zip(graph_list, full_file_list):
        if KEY.INFO not in graph:
            graph[KEY.INFO] = {}
        graph[KEY.INFO].update({'file': os.path.abspath(path)})
        keys.update(graph[KEY.INFO].keys())
    # save only safe subset of info (for batching)
    for graph in graph_list:
        info_dict = graph[KEY.INFO]
        info_dict.update({k: '' for k in keys if k not in info_dict})


def inference(
    checkpoint: str,
    targets: Union[str, List[str]],
    output_dir: str,
    num_workers: int = 1,
    device: str = 'cpu',
    batch_size: int = 4,
    save_graph: bool = False,
    allow_unlabeled: bool = False,
    modal: Optional[str] = None,
    **data_kwargs,
) -> None:
    """
    Run the model on the target dataset and write per-graph and per-atom
    inference results in csv format to output_dir.
    If a given target doesn't have an EFS key, dummy values are written.

    Args:
        checkpoint: model checkpoint path
        targets: path, or list of paths, to evaluate. Supports
            ASE readable files, sevenn_data/*.pt, .sevenn_data, and
            structure_list
        output_dir: directory to write results
        num_workers: number of workers to build graphs
        device: device to evaluate on, defaults to 'cpu'
        batch_size: batch size for inference
        save_graph: if True, save the preprocessed graph to output_dir
        data_kwargs: keyword arguments used when reading targets;
            for example, given index='-1', only the last snapshot
            will be evaluated if the target was ASE readable.

    While this function can handle different types of targets
    at once, it will not work smoothly with data_kwargs.
    """
    model, _ = util.model_from_checkpoint(checkpoint)
    cutoff = model.cutoff

    if modal:
        if model.modal_map is None:
            raise ValueError('Modality given, but model has no modal_map')
        if modal not in model.modal_map:
            _modals = list(model.modal_map.keys())
            raise ValueError(f'Unknown modal {modal} (not in {_modals})')

    if isinstance(targets, str):
        targets = [targets]

    full_file_list = []
    if save_graph:
        dataset = SevenNetGraphDataset(
            cutoff=cutoff,
            root=output_dir,
            files=targets,
            process_num_cores=num_workers,
            processed_name='saved_graph.pt',
            **data_kwargs,
        )
        full_file_list = dataset.full_file_list  # TODO: not used currently
    else:
        dataset = []
        for file in targets:
            tmplist = SevenNetGraphDataset.file_to_graph_list(
                file,
                cutoff=cutoff,
                num_cores=num_workers,
                allow_unlabeled=allow_unlabeled,
                **data_kwargs,
            )
            dataset.extend(tmplist)
            full_file_list.extend([os.path.abspath(file)] * len(tmplist))

    if (
        full_file_list is not None
        and len(full_file_list) == len(dataset)
        and not isinstance(dataset, SevenNetGraphDataset)
    ):
        _patch_data_info(dataset, full_file_list)  # type: ignore

    if modal:
        dataset = SevenNetMultiModalDataset({modal: dataset})  # type: ignore

    loader = DataLoader(dataset, batch_size, shuffle=False)  # type: ignore

    model.to(device)
    model.set_is_batch_data(True)
    model.eval()

    rec = util.get_error_recorder()
    output_list = []
    for batch in tqdm(loader):
        batch = batch.to(device)
        output = model(batch).detach().cpu()
        rec.update(output)
        output_list.extend(util.to_atom_graph_list(output))

    errors = rec.epoch_forward()

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with open(
        os.path.join(output_dir, 'errors.txt'), 'w', encoding='utf-8'
    ) as f:
        for key, val in errors.items():
            f.write(f'{key}: {val}\n')

    write_inference_csv(output_list, output_dir)
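A usage sketch for the function above; as the docstring notes, extra keyword arguments such as index are forwarded to the reader. File names here are placeholders:

# Illustrative sketch only: evaluate a checkpoint on a trajectory. Writes
# errors.txt, info.csv, per_graph.csv, and per_atom.csv under output_dir.
from sevenn.scripts.inference import inference

inference(
    checkpoint='checkpoint_best.pth',  # placeholder path
    targets='md_run.extxyz',
    output_dir='./inference_results',
    device='cuda',
    batch_size=8,
    index='-1',  # forwarded to the ASE reader: last snapshot only
)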
mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py (new file, mode 100644)

import os
import warnings

import torch

import sevenn._keys as KEY
import sevenn.util as util
from sevenn.logger import Logger
from sevenn.scripts.convert_model_modality import (
    append_modality_to_model_dct,
    get_single_modal_model_dct,
)


def processing_continue_v2(config):  # simpler
    """
    Replacement of processing_continue; skips the model compatibility check.
    """
    log = Logger()
    continue_dct = config[KEY.CONTINUE]
    log.write('\nContinue found, loading checkpoint\n')

    checkpoint = util.load_checkpoint(continue_dct[KEY.CHECKPOINT])
    model_cp = checkpoint.build_model()
    config_cp = checkpoint.config
    model_state_dict_cp = model_cp.state_dict()

    optimizer_state_dict_cp = (
        checkpoint.optimizer_state_dict
        if not continue_dct[KEY.RESET_OPTIMIZER]
        else None
    )
    scheduler_state_dict_cp = (
        checkpoint.scheduler_state_dict
        if not continue_dct[KEY.RESET_SCHEDULER]
        else None
    )

    # use_statistic_value_of_checkpoint always True
    # Overwrite config from model state dict, so graph_dataset.from_config
    # will not put statistic values to shift, scale, and conv_denominator
    config[KEY.SHIFT] = model_state_dict_cp[
        'rescale_atomic_energy.shift'
    ].tolist()
    config[KEY.SCALE] = model_state_dict_cp[
        'rescale_atomic_energy.scale'
    ].tolist()
    conv_denom = []
    for i in range(config_cp[KEY.NUM_CONVOLUTION]):
        conv_denom.append(
            model_state_dict_cp[f'{i}_convolution.denominator'].item()
        )
    config[KEY.CONV_DENOMINATOR] = conv_denom
    log.writeline(
        f'{KEY.SHIFT}, {KEY.SCALE}, and {KEY.CONV_DENOMINATOR} are '
        + 'overwritten by model_state_dict of checkpoint'
    )

    chem_keys = [
        KEY.TYPE_MAP,
        KEY.NUM_SPECIES,
        KEY.CHEMICAL_SPECIES,
        KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER,
    ]
    config.update({k: config_cp[k] for k in chem_keys})
    log.writeline(
        'chemical_species are overwritten by checkpoint. '
        + f'This model knows {config[KEY.NUM_SPECIES]} species'
    )

    if config_cp.get(KEY.USE_MODALITY, False) != config.get(KEY.USE_MODALITY):
        raise ValueError('use_modality is not same. Check sevenn_cp')

    modal_map = config_cp.get(KEY.MODAL_MAP, None)  # dict | None
    if modal_map and len(modal_map) > 0:
        modalities = list(modal_map.keys())
        log.writeline(f'Multimodal model found: {modalities}')
        log.writeline('use_modality: True')
        config[KEY.USE_MODALITY] = True

    from_epoch = checkpoint.epoch or 0
    log.writeline(f'Checkpoint previous epoch was: {from_epoch}')
    epoch = 1 if continue_dct[KEY.RESET_EPOCH] else from_epoch + 1
    log.writeline(f'epoch start from {epoch}')

    log.writeline('checkpoint loading successful')

    state_dicts = [
        model_state_dict_cp,
        optimizer_state_dict_cp,
        scheduler_state_dict_cp,
    ]

    return state_dicts, epoch


def check_config_compatible(config, config_cp):
    # TODO: check more
    SHOULD_BE_SAME = [
        KEY.NODE_FEATURE_MULTIPLICITY,
        KEY.LMAX,
        KEY.IS_PARITY,
        KEY.RADIAL_BASIS,
        KEY.CUTOFF_FUNCTION,
        KEY.CUTOFF,
        KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS,
        KEY.NUM_CONVOLUTION,
        KEY.USE_BIAS_IN_LINEAR,
        KEY.SELF_CONNECTION_TYPE,
    ]
    for sbs in SHOULD_BE_SAME:
        if config[sbs] == config_cp[sbs]:
            continue
        if sbs == KEY.SELF_CONNECTION_TYPE and config_cp[sbs] == 'MACE':
            warnings.warn(
                'We do not support this version of checkpoints to continue. '
                "Please use self_connection_type='linear' in input.yaml "
                'and train from scratch',
                UserWarning,
            )
        raise ValueError(
            f'Value of {sbs} should be same. '
            f'{config[sbs]} != {config_cp[sbs]}'
        )

    try:
        cntdct = config[KEY.CONTINUE]
    except KeyError:
        return

    TRAINABLE_CONFIGS = [KEY.TRAIN_DENOMINTAOR, KEY.TRAIN_SHIFT_SCALE]
    if (
        any((
            not cntdct[KEY.RESET_SCHEDULER],
            not cntdct[KEY.RESET_OPTIMIZER],
        ))
        and not all(config[k] == config_cp[k] for k in TRAINABLE_CONFIGS)
    ):
        raise ValueError(
            'reset optimizer and scheduler if you want to change '
            + 'trainable configs'
        )
    # TODO: add condition for changed optim/scheduler but not reset


def processing_continue(config):
    log = Logger()
    continue_dct = config[KEY.CONTINUE]
    log.write('\nContinue found, loading checkpoint\n')

    checkpoint = torch.load(
        continue_dct[KEY.CHECKPOINT], map_location='cpu', weights_only=False
    )
    config_cp = checkpoint['config']
    model_cp, config_cp = util.model_from_checkpoint(checkpoint)
    model_state_dict_cp = model_cp.state_dict()

    # it will raise an error if not compatible
    check_config_compatible(config, config_cp)
    log.write('Checkpoint config is compatible\n')

    # for backward compat.
    config.update({KEY._NORMALIZE_SPH: config_cp[KEY._NORMALIZE_SPH]})

    from_epoch = checkpoint['epoch']
    optimizer_state_dict_cp = (
        checkpoint['optimizer_state_dict']
        if not continue_dct[KEY.RESET_OPTIMIZER]
        else None
    )
    scheduler_state_dict_cp = (
        checkpoint['scheduler_state_dict']
        if not continue_dct[KEY.RESET_SCHEDULER]
        else None
    )

    # These could be changed based on given continue_input.yaml
    # ex) adapt to statistics of fine-tuning dataset
    shift_cp = model_state_dict_cp['rescale_atomic_energy.shift'].numpy()
    del model_state_dict_cp['rescale_atomic_energy.shift']
    scale_cp = model_state_dict_cp['rescale_atomic_energy.scale'].numpy()
    del model_state_dict_cp['rescale_atomic_energy.scale']
    conv_denominators = []
    for i in range(config_cp[KEY.NUM_CONVOLUTION]):
        conv_denominators.append(
            (model_state_dict_cp[f'{i}_convolution.denominator']).item()
        )
        del model_state_dict_cp[f'{i}_convolution.denominator']

    # Further handled by processing_dataset.py
    config.update({
        KEY.SHIFT + '_cp': shift_cp,
        KEY.SCALE + '_cp': scale_cp,
        KEY.CONV_DENOMINATOR + '_cp': conv_denominators,
    })

    chem_keys = [
        KEY.TYPE_MAP,
        KEY.NUM_SPECIES,
        KEY.CHEMICAL_SPECIES,
        KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER,
    ]
    config.update({k: config_cp[k] for k in chem_keys})

    if KEY.USE_MODALITY in config_cp.keys() and config_cp[KEY.USE_MODALITY]:
        # checkpoint model is multimodal
        config.update({
            KEY.MODAL_MAP + '_cp': config_cp[KEY.MODAL_MAP],
            KEY.USE_MODALITY + '_cp': True,
            KEY.NUM_MODALITIES + '_cp': len(config_cp[KEY.MODAL_MAP]),
        })
    else:
        config.update({
            KEY.MODAL_MAP + '_cp': {},
            KEY.USE_MODALITY + '_cp': False,
            KEY.NUM_MODALITIES + '_cp': 0,
        })

    log.write(f'checkpoint previous epoch was: {from_epoch}\n')

    # decide start epoch
    reset_epoch = continue_dct[KEY.RESET_EPOCH]
    if reset_epoch:
        start_epoch = 1
        log.write('epoch reset to 1\n')
    else:
        start_epoch = from_epoch + 1
        log.write(f'epoch start from {start_epoch}\n')

    # decide csv file to continue
    init_csv = True
    csv_fname = config_cp[KEY.CSV_LOG]
    if os.path.isfile(csv_fname):
        # I hope python compares dicts well
        if config_cp[KEY.ERROR_RECORD] == config[KEY.ERROR_RECORD]:
            log.writeline('Same metric, csv file will be appended')
            init_csv = False
    else:
        log.writeline(
            f'{csv_fname} file not found, new csv file will be created'
        )

    log.writeline('checkpoint loading was successful')

    state_dicts = [
        model_state_dict_cp,
        optimizer_state_dict_cp,
        scheduler_state_dict_cp,
    ]

    return state_dicts, start_epoch, init_csv


def convert_modality_of_checkpoint_state_dct(config, state_dicts):
    # TODO: this requires updating model state dict after seeing dataset
    model_state_dict_cp, optimizer_state_dict_cp, scheduler_state_dict_cp = (
        state_dicts
    )
    if config[KEY.USE_MODALITY]:
        # current model is multimodal
        num_modalities_cp = len(config[KEY.MODAL_MAP + '_cp'])
        append_modal_length = config[KEY.NUM_MODALITIES] - num_modalities_cp
        model_state_dict_cp = append_modality_to_model_dct(
            model_state_dict_cp, config, num_modalities_cp, append_modal_length
        )
    else:
        # current model is single modal
        if config[KEY.USE_MODALITY + '_cp']:
            # checkpoint model is multimodal
            # change model state dict to single modal, default = "common"
            model_state_dict_cp = get_single_modal_model_dct(
                model_state_dict_cp,
                config,
                config[KEY.DEFAULT_MODAL],
                from_processing_cp=True,
            )
    state_dicts = (
        model_state_dict_cp,
        optimizer_state_dict_cp,
        scheduler_state_dict_cp,
    )
    return state_dicts
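processing_continue stores checkpoint-derived values under keys suffixed with '_cp', next to the user-facing keys, and handle_shift_scale in processing_dataset.py arbitrates between the two later. A minimal sketch of the convention, with illustrative stand-in names:

# Minimal sketch of the '_cp' key convention (illustrative names only):
# checkpoint-derived values live next to the user-facing key, suffixed with
# '_cp', and are resolved later by handle_shift_scale().
SHIFT = 'shift'                                  # stand-in for KEY.SHIFT

config = {SHIFT: 'elemwise_reference_energies'}  # what the user asked for in yaml
config[SHIFT + '_cp'] = [-3.1, -1.2]             # what the checkpoint carried

use_statistic_values_of_checkpoint = True
shift = (
    config[SHIFT + '_cp']
    if use_statistic_values_of_checkpoint
    else config[SHIFT]
)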
mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py (new file, mode 100644)

import os

import torch
import torch.distributed as dist

import sevenn._const as CONST
import sevenn._keys as KEY
from sevenn.logger import Logger
from sevenn.train.dataload import file_to_dataset, match_reader
from sevenn.train.dataset import AtomGraphDataset
from sevenn.util import chemical_species_preprocess, onehot_to_chem


def dataset_load(file: str, config):
    """
    Wrapper of dataload.file_to_dataset to support
    graph-prebuilt .sevenn_data files.
    """
    log = Logger()
    log.write(f'Loading {file}\n')
    log.timer_start('loading dataset')

    if file.endswith('.sevenn_data'):
        dataset = torch.load(file, map_location='cpu', weights_only=False)
    else:
        reader, _ = match_reader(
            config[KEY.DATA_FORMAT], **config[KEY.DATA_FORMAT_ARGS]
        )
        dataset = file_to_dataset(
            file,
            config[KEY.CUTOFF],
            config[KEY.PREPROCESS_NUM_CORES],
            reader=reader,
            use_modality=config[KEY.USE_MODALITY],
            use_weight=config[KEY.USE_WEIGHT],
        )
    log.format_k_v('loaded dataset size is', dataset.len(), write=True)
    log.timer_end('loading dataset', 'data set loading time')
    return dataset


def calculate_shift_or_scale_from_key(
    train_set: AtomGraphDataset, key_given, n_chem
):
    _expand = True
    use_species_wise_shift_scale = False
    if key_given == 'per_atom_energy_mean':
        shift_or_scale = train_set.get_per_atom_energy_mean()
    elif key_given == 'elemwise_reference_energies':
        shift_or_scale = train_set.get_species_ref_energy_by_linear_comb(n_chem)
        _expand = False
        use_species_wise_shift_scale = True
    elif key_given == 'force_rms':
        shift_or_scale = train_set.get_force_rms()
    elif key_given == 'per_atom_energy_std':
        shift_or_scale = train_set.get_statistics(KEY.PER_ATOM_ENERGY)[
            'Total'
        ]['std']
    elif key_given == 'elemwise_force_rms':
        shift_or_scale = train_set.get_species_wise_force_rms(n_chem)
        _expand = False
        use_species_wise_shift_scale = True
    return shift_or_scale, _expand, use_species_wise_shift_scale


def handle_shift_scale(config, train_set: AtomGraphDataset, checkpoint_given):
    """
    Decide shift, scale, and conv_denominator.
    Priority (higher priority is applied later, overwriting the others):
        1. float value(s) given in yaml
        2. statistic values of checkpoint (if use_statistic_values_of_checkpoint)
        3. plain options, provided as strings (computed from the dataset)
    """
    log = Logger()
    shift, scale, conv_denominator = None, None, None
    type_map = config[KEY.TYPE_MAP]
    n_chem = len(type_map)
    chem_strs = onehot_to_chem(list(range(n_chem)), type_map)

    log.writeline('\nCalculating statistic values from dataset')

    shift_given = config[KEY.SHIFT]
    scale_given = config[KEY.SCALE]
    _expand_shift = True
    _expand_scale = True
    use_species_wise_shift = False
    use_species_wise_scale = False
    use_modal_wise_shift = config[KEY.USE_MODAL_WISE_SHIFT]
    use_modal_wise_scale = config[KEY.USE_MODAL_WISE_SCALE]

    if shift_given in CONST.IMPLEMENTED_SHIFT:
        shift, _expand_shift, use_species_wise_shift = (
            calculate_shift_or_scale_from_key(train_set, shift_given, n_chem)
        )
    if scale_given in CONST.IMPLEMENTED_SCALE:
        scale, _expand_scale, use_species_wise_scale = (
            calculate_shift_or_scale_from_key(train_set, scale_given, n_chem)
        )

    if use_modal_wise_shift or use_modal_wise_scale:
        atomdata_dict_sort_by_modal = train_set.get_dict_sort_by_modality()
        modal_map = config[KEY.MODAL_MAP]
        n_modal = len(modal_map)
        cutoff = config[KEY.CUTOFF]
        if use_modal_wise_shift:
            shift = torch.zeros((n_modal, n_chem))
        if use_modal_wise_scale:
            scale = torch.zeros((n_modal, n_chem))
        for modal_key, data_list in atomdata_dict_sort_by_modal.items():
            modal_set = AtomGraphDataset(
                data_list, cutoff, x_is_one_hot_idx=True
            )
            if use_modal_wise_shift:
                if shift_given == 'elemwise_reference_energies':
                    modal_shift, _expand_shift, use_species_wise_shift = (
                        calculate_shift_or_scale_from_key(
                            modal_set, shift_given, n_chem
                        )
                    )
                    shift[modal_map[modal_key]] = torch.tensor(
                        modal_shift
                    )  # this is np.array
                elif shift_given in CONST.IMPLEMENTED_SHIFT:
                    raise NotImplementedError(
                        'Currently, modal-wise shift is implemented for the '
                        'species-dependent case only.'
                    )
            if use_modal_wise_scale:
                if scale_given == 'elemwise_force_rms':
                    modal_scale, _expand_scale, use_species_wise_scale = (
                        calculate_shift_or_scale_from_key(
                            modal_set, scale_given, n_chem
                        )
                    )
                    scale[modal_map[modal_key]] = modal_scale
                elif scale_given in CONST.IMPLEMENTED_SCALE:
                    raise NotImplementedError(
                        'Currently, modal-wise scale is implemented for the '
                        'species-dependent case only.'
                    )

    avg_num_neigh = train_set.get_avg_num_neigh()
    log.format_k_v('Average # of neighbors', f'{avg_num_neigh:.6f}', write=True)

    if config[KEY.CONV_DENOMINATOR] == 'avg_num_neigh':
        conv_denominator = avg_num_neigh
    elif config[KEY.CONV_DENOMINATOR] == 'sqrt_avg_num_neigh':
        conv_denominator = avg_num_neigh ** (0.5)

    if (
        checkpoint_given
        and config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT]
    ):
        log.writeline(
            'Overwrite shift, scale, conv_denominator from model checkpoint'
        )
        # TODO: This needs refactoring
        conv_denominator = config[KEY.CONV_DENOMINATOR + '_cp']
        if not (use_modal_wise_shift or use_modal_wise_scale):
            # Values extracted from checkpoint in processing_continue.py
            if len(list(shift)) > 1:
                use_species_wise_shift = True
                use_species_wise_scale = True
                _expand_shift = _expand_scale = False
            else:
                shift = shift.item()
                scale = scale.item()
        else:
            # Case of modal-wise shift scale
            shift_cp = config[KEY.SHIFT + '_cp']
            scale_cp = config[KEY.SCALE + '_cp']
            if not use_modal_wise_shift:
                shift = shift_cp
            if not use_modal_wise_scale:
                scale = scale_cp
            modal_map = config[KEY.MODAL_MAP]
            modal_map_cp = config[KEY.MODAL_MAP + '_cp']
            # Extracting shift, scale for modal in checkpoint model.
            if config[KEY.USE_MODALITY + '_cp']:
                # cp model is multimodal
                for modal_key_cp, modal_idx_cp in modal_map_cp.items():
                    modal_idx = modal_map[modal_key_cp]
                    if use_modal_wise_shift:
                        shift[modal_idx] = torch.tensor(shift_cp[modal_idx_cp])
                    if use_modal_wise_scale:
                        scale[modal_idx] = torch.tensor(scale_cp[modal_idx_cp])
            else:
                # cp model is single modal
                try:
                    modal_idx = modal_map[config[KEY.DEFAULT_MODAL]]
                except KeyError:
                    raise KeyError(
                        f'{config[KEY.DEFAULT_MODAL]} should be one of'
                        f' {modal_map.keys()}'
                    )
                if use_modal_wise_shift:
                    shift[modal_idx] = torch.tensor(shift_cp)
                if use_modal_wise_scale:
                    scale[modal_idx] = torch.tensor(scale_cp)
            if not config[KEY.CONTINUE][
                KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY
            ]:
                # Also overwrite values of new modal to reference value
                # For multimodal, set reference modal with KEY.DEFAULT_MODAL
                shift_ref = shift_cp
                scale_ref = scale_cp
                if config[KEY.USE_MODALITY + '_cp']:
                    try:
                        modal_idx_cp = modal_map_cp[config[KEY.DEFAULT_MODAL]]
                    except KeyError:
                        raise KeyError(
                            f'{config[KEY.DEFAULT_MODAL]} should be one of'
                            f' {modal_map_cp.keys()}'
                        )
                    shift_ref = shift_cp[modal_idx_cp]
                    scale_ref = scale_cp[modal_idx_cp]
                for modal_key, modal_idx in modal_map.items():
                    if modal_key not in modal_map_cp.keys():
                        if use_modal_wise_shift:
                            shift[modal_idx] = shift_ref
                        if use_modal_wise_scale:
                            scale[modal_idx] = scale_ref

    # overwrite shift scale anyway if defined in yaml.
    if type(shift_given) in [list, float]:
        log.writeline('Overwrite shift to value(s) given in yaml')
        _expand_shift = isinstance(shift_given, float)
        shift = shift_given
    if type(scale_given) in [list, float]:
        log.writeline('Overwrite scale to value(s) given in yaml')
        _expand_scale = isinstance(scale_given, float)
        scale = scale_given
    if isinstance(config[KEY.CONV_DENOMINATOR], float):
        log.writeline('Overwrite conv_denominator to value given in yaml')
        conv_denominator = config[KEY.CONV_DENOMINATOR]

    if isinstance(conv_denominator, float):
        conv_denominator = [conv_denominator] * config[KEY.NUM_CONVOLUTION]

    use_species_wise_shift_scale = (
        use_species_wise_shift or use_species_wise_scale
    )
    if use_species_wise_shift_scale:
        chem_strs = onehot_to_chem(list(range(n_chem)), type_map)
        if _expand_shift:
            if use_modal_wise_shift:
                shift = torch.full((n_modal, n_chem), shift)
            else:
                shift = [shift] * n_chem
        if _expand_scale:
            if use_modal_wise_scale:
                scale = torch.full((n_modal, n_chem), scale)
            else:
                scale = [scale] * n_chem
        Logger().write('Use element-wise shift, scale\n')
        if use_modal_wise_shift or use_modal_wise_scale:
            for modal_key, modal_idx in modal_map.items():
                Logger().writeline(f'For modal = {modal_key}')
                print_shift = (
                    shift[modal_idx] if use_modal_wise_shift else shift
                )
                print_scale = (
                    scale[modal_idx] if use_modal_wise_scale else scale
                )
                for cstr, sh, sc in zip(chem_strs, print_shift, print_scale):
                    Logger().format_k_v(
                        f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True
                    )
        else:
            for cstr, sh, sc in zip(chem_strs, shift, scale):
                Logger().format_k_v(
                    f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True
                )
    else:
        log.write('Use global shift, scale\n')
        log.format_k_v('shift, scale', f'{shift:.6f}, {scale:.6f}', write=True)

    assert isinstance(conv_denominator, list) and all(
        isinstance(deno, float) for deno in conv_denominator
    )
    log.format_k_v(
        '(1st) conv_denominator is', f'{conv_denominator[0]:.6f}', write=True
    )

    config[KEY.USE_SPECIES_WISE_SHIFT_SCALE] = use_species_wise_shift_scale
    return shift, scale, conv_denominator


# TODO: This is too long
def processing_dataset(config, working_dir):
    log = Logger()
    prefix = f'{os.path.abspath(working_dir)}/'
    is_stress = config[KEY.IS_TRAIN_STRESS]
    checkpoint_given = config[KEY.CONTINUE][KEY.CHECKPOINT] is not False
    cutoff = config[KEY.CUTOFF]

    log.write('\nInitializing dataset...\n')

    dataset = AtomGraphDataset({}, cutoff)
    load_dataset = config[KEY.LOAD_DATASET]
    if type(load_dataset) is str:
        load_dataset = [load_dataset]
    for file in load_dataset:
        dataset.augment(dataset_load(file, config))

    dataset.group_by_key()  # apply labels inside original datapoint
    dataset.unify_dtypes()  # unify dtypes of all data points

    # TODO: I think manual chemical species input is redundant
    chem_in_db = dataset.get_species()
    if config[KEY.CHEMICAL_SPECIES] == 'auto' and not checkpoint_given:
        log.writeline('Auto detect chemical species from dataset')
        config.update(chemical_species_preprocess(chem_in_db))
    elif config[KEY.CHEMICAL_SPECIES] == 'auto' and checkpoint_given:
        pass  # copied from checkpoint in processing_continue.py
    elif config[KEY.CHEMICAL_SPECIES] != 'auto' and not checkpoint_given:
        pass  # processed in parse_input.py
    else:  # config[KEY.CHEMICAL_SPECIES] != "auto" and checkpoint_given
        log.writeline('Ignore chemical species in yaml, use checkpoint')
        # already processed in processing_continue.py

    # basic dataset compatibility check with previous model
    if checkpoint_given:
        chem_from_cp = config[KEY.CHEMICAL_SPECIES]
        if not all(chem in chem_from_cp for chem in chem_in_db):
            raise ValueError(
                'Chemical species in checkpoint is not compatible'
            )

    # check what modalities are used in dataset
    if config[KEY.USE_MODALITY]:
        modalities = dataset.get_modalities()
        num_modalities = len(modalities)
        if num_modalities < 2:
            Logger().writeline('Only one modal is given, ignore modality')
            config.update({KEY.USE_MODALITY: False})
        else:
            modal_map_cp = (
                config[KEY.MODAL_MAP + '_cp'] if checkpoint_given else {}
            )
            modal_map = modal_map_cp.copy()
            current_idx = len(modal_map_cp)
            for modal_key in modalities:
                if modal_key not in modal_map.keys():
                    modal_map[modal_key] = current_idx
                    current_idx += 1
            if config[KEY.IS_DDP]:
                # Synchronize modal_map
                torch.cuda.set_device(config[KEY.LOCAL_RANK])
                modal_map_bcast = [modal_map]
                dist.broadcast_object_list(modal_map_bcast, src=0)
                modal_map = modal_map_bcast[0]
            config.update(
                {
                    KEY.NUM_MODALITIES: len(modal_map),
                    KEY.MODAL_MAP: modal_map,
                    KEY.MODAL_LIST: list(modal_map.keys()),
                }
            )
            dataset.write_modal_attr(
                modal_map,
                config[KEY.USE_MODAL_WISE_SHIFT]
                or config[KEY.USE_MODAL_WISE_SCALE],
            )

    # --------------- save dataset regardless of train/valid --------------#
    save_dataset = config[KEY.SAVE_DATASET]
    save_by_label = config[KEY.SAVE_BY_LABEL]
    if save_dataset:
        if not save_dataset.endswith('.sevenn_data'):
            save_dataset += '.sevenn_data'
        if not (save_dataset.startswith('.') or save_dataset.startswith('/')):
            # save_dataset is a plain file name
            save_dataset = prefix + save_dataset
        dataset.save(save_dataset)
        log.format_k_v('Dataset saved to', save_dataset, write=True)
        # log.write(f"Loaded full dataset saved to : {save_dataset}\n")
    if save_by_label:
        dataset.save(prefix, by_label=True)
        log.format_k_v('Dataset saved by label', prefix, write=True)
    # --------------------------------------------------------------------#

    # TODO: testset is not used
    ignore_test = not config.get(KEY.USE_TESTSET, False)
    if KEY.LOAD_VALIDSET in config and config[KEY.LOAD_VALIDSET]:
        train_set = dataset
        test_set = AtomGraphDataset([], config[KEY.CUTOFF])
        log.write('Loading validset from load_validset\n')
        valid_set = AtomGraphDataset({}, cutoff)
        for file in config[KEY.LOAD_VALIDSET]:
            valid_set.augment(dataset_load(file, config))
        valid_set.group_by_key()
        valid_set.unify_dtypes()

        # condition: validset labels should be subset of trainset labels
        valid_labels = valid_set.user_labels
        train_labels = train_set.user_labels
        if not set(valid_labels).issubset(set(train_labels)):
            valid_set = AtomGraphDataset(valid_set.to_list(), cutoff)
            valid_set.rewrite_labels_to_data()
            train_set = AtomGraphDataset(train_set.to_list(), cutoff)
            train_set.rewrite_labels_to_data()
            Logger().write(
                'WARNING! validset labels is not subset of trainset\n'
            )
            Logger().write(
                'We overwrite all the train, valid labels to default.\n'
            )
            Logger().write(
                'Please create validset by sevenn_graph_build with -l\n'
            )

        Logger().write('the validset loaded, load_dataset is now train_set\n')
        Logger().write('the ratio will be ignored\n')

        # condition: validset modalities should be subset of trainset modalities
        if config[KEY.USE_MODALITY]:
            config_modality = config[KEY.MODAL_LIST]
            valid_modality = valid_set.get_modalities()
            if not set(valid_modality).issubset(set(config_modality)):
                raise ValueError(
                    'validset modality is not subset of trainset'
                )
            valid_set.write_modal_attr(
                config[KEY.MODAL_MAP],
                config[KEY.USE_MODAL_WISE_SHIFT]
                or config[KEY.USE_MODAL_WISE_SCALE],
            )
    else:
        train_set, valid_set, test_set = dataset.divide_dataset(
            config[KEY.RATIO], ignore_test=ignore_test
        )
        log.write(f'The dataset divided into train, valid by {KEY.RATIO}\n')

    log.format_k_v('\nloaded trainset size is', train_set.len(), write=True)
    log.format_k_v('\nloaded validset size is', valid_set.len(), write=True)
    log.write('Dataset initialization was successful\n')

    log.write('\nNumber of atoms in the train_set:\n')
    log.natoms_write(train_set.get_natoms(config[KEY.TYPE_MAP]))

    log.bar()
    log.write('Per atom energy(eV/atom) distribution:\n')
    log.statistic_write(train_set.get_statistics(KEY.PER_ATOM_ENERGY))
    log.bar()
    log.write('Force(eV/Angstrom) distribution:\n')
    log.statistic_write(train_set.get_statistics(KEY.FORCE))
    log.bar()
    log.write('Stress(eV/Angstrom^3) distribution:\n')
    try:
        log.statistic_write(train_set.get_statistics(KEY.STRESS))
    except KeyError:
        log.write('\nStress is not included in the train_set\n')
        if is_stress:
            is_stress = False
            log.write('Turn off stress training\n')
    log.bar()

    # saved data must have atomic numbers as X, not one-hot idx
    if config[KEY.SAVE_BY_TRAIN_VALID]:
        train_set.save(prefix + 'train')
        valid_set.save(prefix + 'valid')
        log.format_k_v('Dataset saved by train, valid', prefix, write=True)

    # inconsistent .info dict gives an error when collating
    _, _ = train_set.separate_info()
    _, _ = valid_set.separate_info()

    if train_set.x_is_one_hot_idx is False:
        train_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])
    if valid_set.x_is_one_hot_idx is False:
        valid_set.x_to_one_hot_idx(config[KEY.TYPE_MAP])

    log.format_k_v('training_set size', train_set.len(), write=True)
    log.format_k_v('validation_set size', valid_set.len(), write=True)

    shift, scale, conv_denominator = handle_shift_scale(
        config, train_set, checkpoint_given
    )
    config.update(
        {
            KEY.SHIFT: shift,
            KEY.SCALE: scale,
            KEY.CONV_DENOMINATOR: conv_denominator,
        }
    )

    data_lists = (train_set.to_list(), valid_set.to_list(), test_set.to_list())
    return data_lists
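handle_shift_scale ends by broadcasting a scalar shift or scale to the shapes the model expects. A minimal sketch of those expansion rules (values illustrative, not from the commit):

# Minimal sketch of the expansion rules in handle_shift_scale() (values
# illustrative): a scalar shift is broadcast per element, and per
# (modal, element) when modal-wise rescaling is enabled.
import torch

n_chem, n_modal, shift = 3, 2, -1.5

elementwise = [shift] * n_chem                     # [-1.5, -1.5, -1.5]
modalwise = torch.full((n_modal, n_chem), shift)   # shape (2, 3), all -1.5
print(elementwise, tuple(modalwise.shape))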
mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py
0 → 100644
View file @
ca86f720
import
os
from
copy
import
deepcopy
from
typing
import
Optional
import
torch
from
torch.utils.data.distributed
import
DistributedSampler
import
sevenn._keys
as
KEY
from
sevenn.error_recorder
import
ErrorRecorder
from
sevenn.logger
import
Logger
from
sevenn.train.trainer
import
Trainer
def
processing_epoch_v2
(
config
:
dict
,
trainer
:
Trainer
,
loaders
:
dict
,
# dict[str, Dataset]
start_epoch
:
int
=
1
,
train_loader_key
:
str
=
'trainset'
,
error_recorder
:
Optional
[
ErrorRecorder
]
=
None
,
total_epoch
:
Optional
[
int
]
=
None
,
per_epoch
:
Optional
[
int
]
=
None
,
best_metric_loader_key
:
str
=
'validset'
,
best_metric
:
Optional
[
str
]
=
None
,
write_csv
:
bool
=
True
,
working_dir
:
Optional
[
str
]
=
None
,
):
from
sevenn.util
import
unique_filepath
log
=
Logger
()
write_csv
=
write_csv
and
log
.
rank
==
0
    working_dir = working_dir or os.getcwd()
    prefix = f'{os.path.abspath(working_dir)}/'
    total_epoch = total_epoch or config[KEY.EPOCH]
    per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10)
    best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss')
    recorder = error_recorder or ErrorRecorder.from_config(
        config, trainer.loss_functions
    )
    recorders = {k: deepcopy(recorder) for k in loaders}

    best_val = float('inf')
    best_key = None
    if best_metric_loader_key in recorders:
        best_key = recorders[best_metric_loader_key].get_key_str(best_metric)
    if best_key is None:
        log.writeline(
            f'Failed to get error recorder key: {best_metric} or '
            + f'{best_metric_loader_key} is missing. There will be no best '
            + 'checkpoint.'
        )

    csv_path = unique_filepath(f'{prefix}/lc.csv')
    if write_csv:
        head = ['epoch', 'lr']
        for k, rec in recorders.items():
            head.extend(list(rec.get_dct(prefix=k)))
        with open(csv_path, 'w') as f:
            f.write(','.join(head) + '\n')

    if start_epoch == 1:
        path = f'{prefix}/checkpoint_0.pth'  # save first epoch
        trainer.write_checkpoint(path, config=config, epoch=0)

    for epoch in range(start_epoch, total_epoch + 1):  # one indexing
        log.timer_start('epoch')
        lr = trainer.get_lr()
        log.bar()
        log.write(f'Epoch {epoch}/{total_epoch}  lr: {lr:8f}\n')
        log.bar()

        csv_dct = {'epoch': str(epoch), 'lr': f'{lr:8f}'}
        errors = {}
        for k, loader in loaders.items():
            is_train = k == train_loader_key
            if (
                trainer.distributed
                and isinstance(loader.sampler, DistributedSampler)
                and is_train
                and config.get('train_shuffle', True)
            ):
                loader.sampler.set_epoch(epoch)
            rec = recorders[k]
            trainer.run_one_epoch(loader, is_train, rec)
            csv_dct.update(rec.get_dct(prefix=k))
            errors[k] = rec.epoch_forward()

        log.write_full_table(list(errors.values()), list(errors))
        trainer.scheduler_step(best_val)

        if write_csv:
            with open(csv_path, 'a') as f:
                f.write(','.join(list(csv_dct.values())) + '\n')

        if best_key and errors[best_metric_loader_key][best_key] < best_val:
            path = f'{prefix}/checkpoint_best.pth'
            trainer.write_checkpoint(path, config=config, epoch=epoch)
            best_val = errors[best_metric_loader_key][best_key]
            log.writeline('Best checkpoint written')

        if epoch % per_epoch == 0:
            path = f'{prefix}/checkpoint_{epoch}.pth'
            trainer.write_checkpoint(path, config=config, epoch=epoch)
        log.timer_end('epoch', message=f'Epoch {epoch} elapsed')

    return trainer


def processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir):
    log = Logger()
    prefix = f'{os.path.abspath(working_dir)}/'
    train_loader, valid_loader = loaders
    is_distributed = config[KEY.IS_DDP]
    rank = config[KEY.RANK]
    total_epoch = config[KEY.EPOCH]
    per_epoch = config[KEY.PER_EPOCH]
    train_recorder = ErrorRecorder.from_config(config)
    valid_recorder = ErrorRecorder.from_config(config)
    best_metric = config[KEY.BEST_METRIC]
    csv_fname = f'{prefix}{config[KEY.CSV_LOG]}'
    current_best = float('inf')

    if init_csv:
        csv_header = ['Epoch', 'Learning_rate']
        # Assume train valid have the same metrics
        for metric in train_recorder.get_metric_dict().keys():
            csv_header.append(f'Train_{metric}')
            csv_header.append(f'Valid_{metric}')
        log.init_csv(csv_fname, csv_header)

    def write_checkpoint(epoch, is_best=False):
        if is_distributed and rank != 0:
            return
        suffix = '_best' if is_best else f'_{epoch}'
        checkpoint = trainer.get_checkpoint_dict()
        checkpoint.update({'config': config, 'epoch': epoch})
        torch.save(checkpoint, f'{prefix}/checkpoint{suffix}.pth')

    fin_epoch = total_epoch + start_epoch
    for epoch in range(start_epoch, fin_epoch):
        lr = trainer.get_lr()
        log.timer_start('epoch')
        log.bar()
        log.write(f'Epoch {epoch}/{fin_epoch - 1}  lr: {lr:8f}\n')
        log.bar()

        trainer.run_one_epoch(
            train_loader, is_train=True, error_recorder=train_recorder
        )
        train_err = train_recorder.epoch_forward()

        trainer.run_one_epoch(valid_loader, error_recorder=valid_recorder)
        valid_err = valid_recorder.epoch_forward()

        csv_values = [epoch, lr]
        for metric in train_err:
            csv_values.append(train_err[metric])
            csv_values.append(valid_err[metric])
        log.append_csv(csv_fname, csv_values)

        log.write_full_table([train_err, valid_err], ['Train', 'Valid'])

        val = None
        for metric in valid_err:
            # loose string comparison,
            # e.g. "Energy" in "TotalEnergy" or "Energy_Loss"
            if best_metric in metric:
                val = valid_err[metric]
                break
        assert val is not None, f'Metric {best_metric} not found in {valid_err}'

        trainer.scheduler_step(val)
        log.timer_end('epoch', message=f'Epoch {epoch} elapsed')

        if val < current_best:
            current_best = val
            write_checkpoint(epoch, is_best=True)
            log.writeline('Best checkpoint written')
        if epoch % per_epoch == 0:
            write_checkpoint(epoch)
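Both epoch drivers above share the same checkpoint cadence: a best checkpoint rewritten whenever the tracked validation metric improves, plus a periodic snapshot every `per_epoch` epochs. A minimal, self-contained sketch of that cadence (the `validate` and `save` callables are hypothetical placeholders, not SevenNet APIs):

import math

def checkpoint_cadence(total_epoch, per_epoch, validate, save):
    # validate(epoch) -> scalar metric; save(tag) persists a checkpoint
    best = math.inf
    for epoch in range(1, total_epoch + 1):
        val = validate(epoch)
        if val < best:  # lower is better, as in the loops above
            best = val
            save('best')
        if epoch % per_epoch == 0:  # periodic snapshot
            save(str(epoch))

# toy run: the metric keeps improving, so 'best' is rewritten each epoch
checkpoint_cadence(
    total_epoch=10,
    per_epoch=5,
    validate=lambda e: 1.0 / e,
    save=lambda tag: print(f'checkpoint_{tag}.pth'),
)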
mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py
0 → 100644
from typing import List, Optional

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.loader import DataLoader

import sevenn._keys as KEY
from sevenn.logger import Logger
from sevenn.model_build import build_E3_equivariant_model
from sevenn.scripts.processing_continue import (
    convert_modality_of_checkpoint_state_dct,
)
from sevenn.train.trainer import Trainer


def loader_from_config(config, dataset, is_train=False):
    batch_size = config[KEY.BATCH_SIZE]
    shuffle = is_train and config[KEY.TRAIN_SHUFFLE]
    sampler = None
    loader_args = {
        'dataset': dataset,
        'batch_size': batch_size,
        'shuffle': shuffle,
    }
    if KEY.NUM_WORKERS in config and config[KEY.NUM_WORKERS] > 0:
        loader_args.update({'num_workers': config[KEY.NUM_WORKERS]})
    if config[KEY.IS_DDP]:
        dist.barrier()
        sampler = DistributedSampler(
            dataset, dist.get_world_size(), dist.get_rank(), shuffle=shuffle
        )
        loader_args.update({'sampler': sampler})
        loader_args.pop('shuffle')  # sampler is mutually exclusive with shuffle
    return DataLoader(**loader_args)


def train_v2(config, working_dir: str):
    """
    Main program flow, since v0.9.6
    """
    import sevenn.train.atoms_dataset as atoms_dataset
    import sevenn.train.graph_dataset as graph_dataset
    import sevenn.train.modal_dataset as modal_dataset

    from .processing_continue import processing_continue_v2
    from .processing_epoch import processing_epoch_v2

    log = Logger()
    log.timer_start('total')

    if KEY.LOAD_TRAINSET not in config and KEY.LOAD_DATASET in config:
        log.writeline('***************************************************')
        log.writeline('For train_v2, please use load_trainset_path instead')
        log.writeline('I will assign load_trainset as load_dataset')
        log.writeline('***************************************************')
        config[KEY.LOAD_TRAINSET] = config.pop(KEY.LOAD_DATASET)

    # config updated
    start_epoch = 1
    state_dicts: Optional[List[dict]] = None
    if config[KEY.CONTINUE][KEY.CHECKPOINT]:
        state_dicts, start_epoch = processing_continue_v2(config)

    if config.get(KEY.USE_MODALITY, False):
        datasets = modal_dataset.from_config(config, working_dir)
    elif config[KEY.DATASET_TYPE] == 'graph':
        datasets = graph_dataset.from_config(config, working_dir)
    elif config[KEY.DATASET_TYPE] == 'atoms':
        datasets = atoms_dataset.from_config(config, working_dir)
    else:
        raise ValueError(f'Unknown dataset type: {config[KEY.DATASET_TYPE]}')

    loaders = {
        k: loader_from_config(config, v, is_train=(k == 'trainset'))
        for k, v in datasets.items()
    }

    log.write('\nModel building...\n')
    model = build_E3_equivariant_model(config)
    log.print_model_info(model, config)

    trainer = Trainer.from_config(model, config)
    if state_dicts:
        trainer.load_state_dicts(*state_dicts, strict=False)

    processing_epoch_v2(
        config, trainer, loaders, start_epoch, working_dir=working_dir
    )
    log.timer_end('total', message='Total wall time')


def train(config, working_dir: str):
    """
    Main program flow, until v0.9.5
    """
    from .processing_continue import processing_continue
    from .processing_dataset import processing_dataset
    from .processing_epoch import processing_epoch

    log = Logger()
    log.timer_start('total')

    # config updated
    state_dicts: Optional[List[dict]] = None
    if config[KEY.CONTINUE][KEY.CHECKPOINT]:
        state_dicts, start_epoch, init_csv = processing_continue(config)
    else:
        start_epoch, init_csv = 1, True

    # config updated
    train, valid, _ = processing_dataset(config, working_dir)
    datasets = {'dataset': train, 'validset': valid}
    loaders = {
        k: loader_from_config(config, v, is_train=(k == 'dataset'))
        for k, v in datasets.items()
    }
    loaders = list(loaders.values())

    log.write('\nModel building...\n')
    model = build_E3_equivariant_model(config)
    log.write('Model building was successful\n')

    trainer = Trainer.from_config(model, config)
    if state_dicts:
        state_dicts = convert_modality_of_checkpoint_state_dct(config, state_dicts)
        trainer.load_state_dicts(*state_dicts, strict=False)

    log.print_model_info(model, config)
    log.write('Trainer initialized, ready to train\n')
    log.bar()

    processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir)
    log.timer_end('total', message='Total wall time')
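A detail worth noting in `loader_from_config`: PyTorch's `DataLoader` refuses `shuffle=True` together with an explicit `sampler`, which is why the DDP branch pops `'shuffle'` after attaching the `DistributedSampler`. A small sketch of the failure mode, using the plain torch loader (the PyG `DataLoader` subclasses it and inherits the same check):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
# rank/num_replicas given explicitly, so no process group is needed here
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

try:
    DataLoader(dataset, batch_size=2, shuffle=True, sampler=sampler)
except ValueError as e:
    print(e)  # sampler option is mutually exclusive with shuffle

loader = DataLoader(dataset, batch_size=2, sampler=sampler)  # correct form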
mace-bench/3rdparty/SevenNet/sevenn/train/__init__.py
0 → 100644
mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/__init__.cpython-310.pyc
0 → 100644
File added
mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataload.cpython-310.pyc
0 → 100644
File added
mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataset.cpython-310.pyc
0 → 100644
File added
mace-bench/3rdparty/SevenNet/sevenn/train/atoms_dataset.py
0 → 100644
import os
import random
import warnings
from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch.utils.data
from ase.atoms import Atoms
from ase.data import chemical_symbols
from ase.io import write
from tqdm import tqdm

import sevenn._keys as KEY
import sevenn.train.dataload as dataload
import sevenn.util as util
from sevenn._const import NUM_UNIV_ELEMENT
from sevenn.atom_graph_data import AtomGraphData

_warn_avg_num_neigh = """SevenNetAtomsDataset does not provide an exact avg_num_neigh
as it does not build graphs ahead of time. We build graphs for at most 10000
randomly sampled structures to approximate this value. If you want a more precise
avg_num_neigh, use SevenNetGraphDataset. If that is not viable due to memory
limits, an online algorithm would be needed, which is not yet implemented in
SevenNet"""


class SevenNetAtomsDataset(torch.utils.data.Dataset):
    """
    Args:
        cutoff: edge cutoff of given AtomGraphData
        files: list of filenames or dict describing how to parse the file
            ASE readable (with proper extension), structure_list, .sevenn_data,
            dict containing file_list (see dict_reader of train/dataload.py)
        info_dict_copy_keys: patch these keys from KEY.INFO to graph when accessing.
            default is KEY.DATA_WEIGHT and KEY.DATA_MODALITY, which may be accessed
            while training.
        **process_kwargs: keyword arguments that will be passed into ase.io.read
    """

    def __init__(
        self,
        cutoff: float,
        files: Union[str, List[str]],
        atoms_filter: Optional[Callable] = None,
        atoms_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        use_data_weight: bool = False,
        **process_kwargs,
    ):
        self.cutoff = cutoff
        if isinstance(files, str):
            files = [files]  # user convenience
        files = [os.path.abspath(file) for file in files]
        self._files = files
        self.atoms_filter = atoms_filter
        self.atoms_transform = atoms_transform
        self.transform = transform
        self.use_data_weight = use_data_weight

        self._scanned = False
        self._avg_num_neigh_approx = None
        self.statistics = {}

        atoms_list = []
        for file in files:
            atoms_list.extend(
                SevenNetAtomsDataset.file_to_atoms_list(file, **process_kwargs)
            )
        self._atoms_list = atoms_list
        super().__init__()

    @staticmethod
    def file_to_atoms_list(file: Union[str, dict], **kwargs) -> List[Atoms]:
        if isinstance(file, dict):
            atoms_list = dataload.dict_reader(file)
        elif 'structure_list' in file:
            atoms_dct = dataload.structure_list_reader(file)
            atoms_list = []
            for lst in atoms_dct.values():
                atoms_list.extend(lst)
        else:
            atoms_list = dataload.ase_reader(file, **kwargs)
        return atoms_list

    def save(self, path):
        # Save atoms list as extxyz
        write(path, self._atoms_list, format='extxyz')

    def _graph_build(self, atoms):
        return dataload.atoms_to_graph(
            atoms, self.cutoff, transfer_info=False, y_from_calc=False
        )

    def __len__(self):
        return len(self._atoms_list)

    def __getitem__(self, index):
        atoms = self._atoms_list[index]
        if self.atoms_transform is not None:
            atoms = self.atoms_transform(atoms)
        graph = self._graph_build(atoms)
        if self.transform is not None:
            graph = self.transform(graph)
        if self.use_data_weight:
            weight = graph[KEY.INFO].pop(
                KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
            )
            graph[KEY.DATA_WEIGHT] = weight
        return AtomGraphData.from_numpy_dict(graph)

    @property
    def species(self):
        self.run_stat()
        return [z for z in self.statistics['_natoms'].keys() if z != 'total']

    @property
    def natoms(self):
        self.run_stat()
        return self.statistics['_natoms']

    @property
    def per_atom_energy_mean(self):
        self.run_stat()
        return self.statistics[KEY.PER_ATOM_ENERGY]['mean']

    @property
    def elemwise_reference_energies(self):
        from sklearn.linear_model import Ridge

        c = self.statistics['_composition']
        y = self.statistics[KEY.ENERGY]['_array']
        zero_indices = np.all(c == 0, axis=0)
        c_reduced = c[:, ~zero_indices]
        # will not 100% reproduce, as it is sorted by Z
        # train/dataset.py was sorted by alphabets of chemical species
        coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_
        full_coeff = np.zeros(NUM_UNIV_ELEMENT)
        full_coeff[~zero_indices] = coef_reduced
        return full_coeff.tolist()  # ex: full_coeff[1] = H_reference_energy

    @property
    def force_rms(self):
        self.run_stat()
        mean = self.statistics[KEY.FORCE]['mean']
        std = self.statistics[KEY.FORCE]['std']
        return float((mean**2 + std**2) ** (0.5))

    @property
    def per_atom_energy_std(self):
        self.run_stat()
        return self.statistics['per_atom_energy']['std']

    @property
    def avg_num_neigh(self, n_sample=10000):
        if self._avg_num_neigh_approx is None:
            if len(self) > n_sample:
                warnings.warn(_warn_avg_num_neigh)
            n_sample = min(len(self), n_sample)
            indices = random.sample(range(len(self)), n_sample)
            n_neigh = []
            for i in indices:
                graph = self[i]
                _, nn = np.unique(graph[KEY.EDGE_IDX][0], return_counts=True)
                n_neigh.append(nn)
            n_neigh = np.concatenate(n_neigh)
            self._avg_num_neigh_approx = np.mean(n_neigh)
        return self._avg_num_neigh_approx

    @property
    def sqrt_avg_num_neigh(self):
        self.run_stat()
        return self.avg_num_neigh**0.5

    def run_stat(self):
        """
        Loop over dataset and init any statistics we might need.
        Unlike SevenNetGraphDataset, the neighbor count is not computed here,
        as it requires building graphs
        """
        if self._scanned is True:
            return  # statistics already computed

        y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS]
        natoms_counter = Counter()
        composition = np.zeros((len(self), NUM_UNIV_ELEMENT))
        stats: Dict[str, Dict[str, Any]] = {y: {'_array': []} for y in y_keys}

        for i, atoms in tqdm(
            enumerate(self._atoms_list), desc='run_stat', total=len(self)
        ):
            z = atoms.get_atomic_numbers()
            natoms_counter.update(z.tolist())
            composition[i] = np.bincount(z, minlength=NUM_UNIV_ELEMENT)
            for y, dct in stats.items():
                if y == KEY.ENERGY:
                    dct['_array'].append(atoms.info['y_energy'])
                elif y == KEY.PER_ATOM_ENERGY:
                    dct['_array'].append(atoms.info['y_energy'] / len(atoms))
                elif y == KEY.FORCE:
                    dct['_array'].append(atoms.arrays['y_force'].reshape(-1))
                elif y == KEY.STRESS:
                    dct['_array'].append(atoms.info['y_stress'].reshape(-1))

        for y, dct in stats.items():
            if y == KEY.FORCE:
                array = np.concatenate(dct['_array'])
            else:
                array = np.array(dct['_array']).reshape(-1)
            dct.update(
                {
                    'mean': float(np.mean(array)),
                    'std': float(np.std(array)),
                    'median': float(np.quantile(array, q=0.5)),
                    'max': float(np.max(array)),
                    'min': float(np.min(array)),
                    '_array': array,
                }
            )

        natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()}
        natoms['total'] = sum(list(natoms.values()))
        self.statistics.update(
            {
                '_composition': composition,
                '_natoms': natoms,
                **stats,
            }
        )
        self._scanned = True
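A minimal usage sketch of the class above, assuming only ASE and sevenn are importable: write a tiny labeled extxyz file (the energies and forces below are fabricated placeholders), then let the dataset build graphs lazily on access.

import numpy as np
from ase.build import bulk
from ase.calculators.singlepoint import SinglePointCalculator
from ase.io import write

from sevenn.train.atoms_dataset import SevenNetAtomsDataset

atoms_list = []
for a in (3.5, 3.6, 3.7):
    atoms = bulk('Cu', 'fcc', a=a).repeat(2)
    # placeholder labels attached as single-point results
    atoms.calc = SinglePointCalculator(
        atoms, energy=-3.0 * len(atoms), forces=np.zeros((len(atoms), 3))
    )
    atoms_list.append(atoms)
write('toy.extxyz', atoms_list, format='extxyz')

dataset = SevenNetAtomsDataset(cutoff=4.5, files='toy.extxyz')
print(len(dataset), dataset[0])  # graphs are built per access, not up front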
# script; returns dict of SevenNetAtomsDataset
def from_config(
    config: Dict[str, Any],
    working_dir: str = os.getcwd(),
    dataset_keys: Optional[List[str]] = None,
):
    from sevenn.logger import Logger

    log = Logger()
    if dataset_keys is None:
        dataset_keys = []
        for k in config:
            if k.startswith('load_') and k.endswith('_path'):
                dataset_keys.append(k)

    if KEY.LOAD_TRAINSET not in dataset_keys:
        raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config')

    # initialize arguments for loading dataset
    dataset_args = {
        'cutoff': config[KEY.CUTOFF],
        'use_data_weight': config.get(KEY.USE_WEIGHT, False),
        **config[KEY.DATA_FORMAT_ARGS],
    }

    datasets = {}
    for dk in dataset_keys:
        if not (paths := config[dk]):
            continue
        if isinstance(paths, str):
            paths = [paths]
        name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]])
        dataset_args.update({'files': paths})
        datasets[name] = SevenNetAtomsDataset(**dataset_args)

    if not config[KEY.COMPUTE_STATISTICS]:
        log.writeline(
            (
                'Computing statistics is skipped. Note that if any other '
                'configuration requires statistics (shift, scale, avg_num_neigh, '
                "chemical_species as 'auto'), SevenNet will eventually raise "
                'an error!'
            )
        )
        return datasets

    train_set = datasets['trainset']

    chem_species = set(train_set.species)
    # print statistics of each dataset
    for name, dataset in datasets.items():
        dataset.run_stat()
        log.bar()
        log.writeline(f'{name} distribution:')
        log.statistic_write(dataset.statistics)
        log.format_k_v('# atoms (node)', dataset.natoms, write=True)
        log.format_k_v('# structures (graph)', len(dataset), write=True)
        chem_species.update(dataset.species)
    log.bar()

    # initialize known species from dataset if 'auto'
    # sorted to alphabetical order (which is same as before)
    chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
    if all([config[ck] == 'auto' for ck in chem_keys]):  # see parse_input.py
        log.writeline('Known species are obtained from the dataset')
        config.update(util.chemical_species_preprocess(sorted(list(chem_species))))

    # retrieve shift, scale, conv_denominators from user input (keyword)
    init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
    for k in init_from_stats:
        input = config[k]  # statistic key or numbers
        # If it is not 'str', 1: It is 'continue' training
        # 2: User manually inserted numbers
        if isinstance(input, str) and hasattr(train_set, input):
            var = getattr(train_set, input)
            config.update({k: var})
            log.writeline(f'{k} is obtained from statistics')
        elif isinstance(input, str) and not hasattr(train_set, input):
            raise NotImplementedError(input)

    return datasets
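The shift/scale/conv_denominator resolution above is plain attribute lookup: a string value in the config is treated as the name of a statistic property on the train set. The mechanism in isolation, with a stand-in object rather than the real dataset:

class ToyStats:
    per_atom_energy_mean = -3.2  # stand-in for the dataset property

config = {'shift': 'per_atom_energy_mean', 'scale': 1.0}
train_set = ToyStats()
for k in ('shift', 'scale'):
    v = config[k]
    if isinstance(v, str) and hasattr(train_set, v):
        config[k] = getattr(train_set, v)
print(config)  # {'shift': -3.2, 'scale': 1.0}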
mace-bench/3rdparty/SevenNet/sevenn/train/collate.py
0 → 100644
View file @
ca86f720
from typing import Any, List, Optional, Sequence

from ase.atoms import Atoms
from torch_geometric.loader.dataloader import Collater

from sevenn.atom_graph_data import AtomGraphData

from .dataload import atoms_to_graph


class AtomsToGraphCollater(Collater):
    def __init__(
        self,
        dataset: Sequence[Atoms],
        cutoff: float,
        transfer_info: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        y_from_calc: bool = True,
    ):
        # quiet the original Collater's type mismatch by passing []
        super().__init__([], follow_batch, exclude_keys)
        self.dataset = dataset
        self.cutoff = cutoff
        self.transfer_info = transfer_info
        self.y_from_calc = y_from_calc

    def __call__(self, batch: List[Any]) -> Any:
        # build list of graph
        graph_list = []
        for stct in batch:
            graph = atoms_to_graph(
                stct,
                self.cutoff,
                transfer_info=self.transfer_info,
                y_from_calc=self.y_from_calc,
            )
            graph = AtomGraphData.from_numpy_dict(graph)
            graph_list.append(graph)
        return super().__call__(graph_list)
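A usage sketch: because the collater turns `Atoms` into graphs at batch time, a plain list of ASE `Atoms` can back a standard `DataLoader`. With `y_from_calc=True`, labels are read from each structure's calculator, so single-point results are attached here (the numbers are placeholders):

import numpy as np
from ase.build import molecule
from ase.calculators.singlepoint import SinglePointCalculator
from torch.utils.data import DataLoader

from sevenn.train.collate import AtomsToGraphCollater

atoms_list = []
for name in ('H2O', 'NH3', 'CH4'):
    atoms = molecule(name)
    atoms.calc = SinglePointCalculator(
        atoms, energy=0.0, forces=np.zeros((len(atoms), 3))
    )
    atoms_list.append(atoms)

collater = AtomsToGraphCollater(atoms_list, cutoff=4.0, y_from_calc=True)
loader = DataLoader(atoms_list, batch_size=2, collate_fn=collater)
batch = next(iter(loader))
print(batch.num_graphs)  # 2; graphs were built on the fly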
mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py
0 → 100644
import copy
import os.path
from functools import partial
from itertools import chain, islice
from typing import Callable, Dict, List, Optional

import ase
import ase.io
import numpy as np
import torch.multiprocessing as mp
from ase.io.vasp_parsers.vasp_outcar_parsers import (
    Cell,
    DefaultParsersContainer,
    Energy,
    OutcarChunkParser,
    PositionsAndForces,
    Stress,
    outcarchunks,
)
from ase.neighborlist import primitive_neighbor_list
from ase.utils import string2index
from braceexpand import braceexpand
from tqdm import tqdm

import sevenn._keys as KEY
from sevenn._const import LossType
from sevenn.atom_graph_data import AtomGraphData

from .dataset import AtomGraphDataset


def _graph_build_matscipy(cutoff: float, pbc, cell, pos):
    pbc_x = pbc[0]
    pbc_y = pbc[1]
    pbc_z = pbc[2]
    identity = np.identity(3, dtype=float)
    max_positions = np.max(np.absolute(pos)) + 1

    # Extend cell in non-periodic directions
    # For models with more than 5 layers,
    # the multiplicative constant needs to be increased.
    if not pbc_x:
        cell[0, :] = max_positions * 5 * cutoff * identity[0, :]
    if not pbc_y:
        cell[1, :] = max_positions * 5 * cutoff * identity[1, :]
    if not pbc_z:
        cell[2, :] = max_positions * 5 * cutoff * identity[2, :]

    # it does not have self-interaction
    edge_src, edge_dst, edge_vec, shifts = neighbour_list(
        quantities='ijDS',
        pbc=pbc,
        cell=cell,
        positions=pos,
        cutoff=cutoff,
    )

    # dtype issue
    edge_src = edge_src.astype(np.int64)
    edge_dst = edge_dst.astype(np.int64)

    return edge_src, edge_dst, edge_vec, shifts


def _graph_build_ase(cutoff: float, pbc, cell, pos):
    # building neighbor list
    edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list(
        'ijDS', pbc, cell, pos, cutoff, self_interaction=True
    )

    is_zero_idx = np.all(edge_vec == 0, axis=1)
    is_self_idx = edge_src == edge_dst
    non_trivials = ~(is_zero_idx & is_self_idx)

    shifts = np.array(shifts[non_trivials])
    edge_vec = edge_vec[non_trivials]
    edge_src = edge_src[non_trivials]
    edge_dst = edge_dst[non_trivials]

    return edge_src, edge_dst, edge_vec, shifts


_graph_build_f = _graph_build_ase
try:
    from matscipy.neighbours import neighbour_list

    _graph_build_f = _graph_build_matscipy
except ImportError:
    pass
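The try/except above makes matscipy an optional dependency: its `neighbour_list` is typically much faster than ASE's `primitive_neighbor_list` for large cells, while both return the same `(i, j, D, S)` quantities. Which backend got picked can be checked at runtime:

import sevenn.train.dataload as dataload

# prints '_graph_build_matscipy' if matscipy is importable, else '_graph_build_ase'
print(dataload._graph_build_f.__name__)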
def _correct_scalar(v):
    if isinstance(v, np.ndarray):
        v = v.squeeze()
        assert v.ndim == 0, f'given {v} is not a scalar'
        return v
    elif isinstance(v, (int, float, np.integer, np.floating)):
        return np.array(v)
    else:
        assert False, f'{type(v)} is not expected'


def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float):
    pos = atoms.get_positions()
    cell = np.array(atoms.get_cell())
    pbc = atoms.get_pbc()

    edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos)
    edge_idx = np.array([edge_src, edge_dst])

    atomic_numbers = atoms.get_atomic_numbers()
    cell = np.array(cell)

    vol = _correct_scalar(atoms.cell.volume)
    if vol == 0:
        vol = np.array(np.finfo(float).eps)

    data = {
        KEY.NODE_FEATURE: atomic_numbers,
        KEY.ATOMIC_NUMBERS: atomic_numbers,
        KEY.POS: pos,
        KEY.EDGE_IDX: edge_idx,
        KEY.EDGE_VEC: edge_vec,
        KEY.CELL: cell,
        KEY.CELL_SHIFT: shifts,
        KEY.CELL_VOLUME: vol,
        KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)),
    }
    data[KEY.INFO] = {}

    return data


def atoms_to_graph(
    atoms: ase.Atoms,
    cutoff: float,
    transfer_info: bool = True,
    y_from_calc: bool = False,
    allow_unlabeled: bool = False,
):
    """
    From ase atoms, return AtomGraphData as graph based on cutoff radius.
    Except for energy, labels (force, stress) must be numpy arrays,
    as other cases are not tested.
    Returns 'np.nan' with consistent shape for unlabeled data
    (ex. stress of non-pbc system)

    Args:
        atoms (Atoms): ase atoms
        cutoff (float): cutoff radius
        transfer_info (bool): if True, transfer ".info" from atoms to graph,
            defaults to True
        y_from_calc: if True, get ref values from calculator, defaults to False

    Returns:
        numpy dict that can be used to initialize AtomGraphData
        by AtomGraphData(**atoms_to_graph(atoms, cutoff)).
        For scalars, the shape is (), and types are np.ndarray.
        requires_grad is handled by 'dataset', not here.
    """
    if not y_from_calc:
        y_energy = atoms.info['y_energy']
        y_force = atoms.arrays['y_force']
        y_stress = atoms.info.get('y_stress', np.full((6,), np.nan))
        if y_stress.shape == (3, 3):
            y_stress = np.array(
                [
                    y_stress[0][0],
                    y_stress[1][1],
                    y_stress[2][2],
                    y_stress[0][1],
                    y_stress[1][2],
                    y_stress[2][0],
                ]
            )
        else:
            y_stress = y_stress.squeeze()
    else:
        from_calc = _y_from_calc(atoms)
        y_energy = from_calc['energy']
        y_force = from_calc['force']
        y_stress = from_calc['stress']

    assert y_stress.shape == (6,), 'If you see this, please raise an issue'

    if not allow_unlabeled and (np.isnan(y_energy) or np.isnan(y_force).any()):
        raise ValueError('Unlabeled E or F found, set allow_unlabeled True')

    pos = atoms.get_positions()
    cell = np.array(atoms.get_cell())
    pbc = atoms.get_pbc()

    edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos)
    edge_idx = np.array([edge_src, edge_dst])

    atomic_numbers = atoms.get_atomic_numbers()
    cell = np.array(cell)

    vol = _correct_scalar(atoms.cell.volume)
    if vol == 0:
        vol = np.array(np.finfo(float).eps)

    data = {
        KEY.NODE_FEATURE: atomic_numbers,
        KEY.ATOMIC_NUMBERS: atomic_numbers,
        KEY.POS: pos,
        KEY.EDGE_IDX: edge_idx,
        KEY.EDGE_VEC: edge_vec,
        KEY.ENERGY: _correct_scalar(y_energy),
        KEY.FORCE: y_force,
        KEY.STRESS: y_stress.reshape(1, 6),  # to make batch have (n_node, 6)
        KEY.CELL: cell,
        KEY.CELL_SHIFT: shifts,
        KEY.CELL_VOLUME: vol,
        KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)),
        KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)),
    }

    if transfer_info and atoms.info is not None:
        info = copy.deepcopy(atoms.info)
        # save only metadata
        info.pop('y_energy', None)
        info.pop('y_force', None)
        info.pop('y_stress', None)
        data[KEY.INFO] = info
    else:
        data[KEY.INFO] = {}

    return data
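A sketch of the labeled path through `atoms_to_graph`: set `y_energy`/`y_force` by hand (placeholder numbers) and inspect the plain-numpy dict it returns; the `KEY.*` constants resolve to ordinary string keys.

import numpy as np
from ase.build import bulk

from sevenn.train.dataload import atoms_to_graph

atoms = bulk('Si', 'diamond', a=5.43)
atoms.info['y_energy'] = -10.84                      # placeholder label
atoms.arrays['y_force'] = np.zeros((len(atoms), 3))  # placeholder label

graph = atoms_to_graph(atoms, cutoff=4.0)  # missing stress -> stored as NaN
print(sorted(graph.keys()))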
def graph_build(
    atoms_list: List,
    cutoff: float,
    num_cores: int = 1,
    transfer_info: bool = True,
    y_from_calc: bool = False,
    allow_unlabeled: bool = False,
) -> List[AtomGraphData]:
    """
    parallel version of graph_build
    build graph from atoms_list and return list of AtomGraphData

    Args:
        atoms_list (List): list of ASE atoms
        cutoff (float): cutoff radius of graph
        num_cores (int): number of cores to use
        transfer_info (bool): if True, copy info from atoms to graph,
            defaults to True
        y_from_calc (bool): Get reference y labels from calculator,
            defaults to False

    Returns:
        List[AtomGraphData]: list of AtomGraphData
    """
    serial = num_cores == 1
    inputs = [
        (atoms, cutoff, transfer_info, y_from_calc, allow_unlabeled)
        for atoms in atoms_list
    ]

    if not serial:
        pool = mp.Pool(num_cores)
        graph_list = pool.starmap(
            atoms_to_graph,
            tqdm(inputs, total=len(atoms_list), desc=f'graph_build ({num_cores})'),
        )
        pool.close()
        pool.join()
    else:
        graph_list = [
            atoms_to_graph(*input_)
            for input_ in tqdm(inputs, desc='graph_build (1)')
        ]

    graph_list = [AtomGraphData.from_numpy_dict(g) for g in graph_list]

    return graph_list


def _y_from_calc(atoms: ase.Atoms):
    ret = {
        'energy': np.nan,
        'force': np.full((len(atoms), 3), np.nan),
        'stress': np.full((6,), np.nan),
    }
    if atoms.calc is None:
        return ret

    try:
        ret['energy'] = atoms.get_potential_energy(force_consistent=True)
    except NotImplementedError:
        ret['energy'] = atoms.get_potential_energy()

    try:
        ret['force'] = atoms.get_forces(apply_constraint=False)
    except NotImplementedError:
        pass

    try:
        y_stress = -1 * atoms.get_stress()  # it ensures correct shape
        ret['stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
    except RuntimeError:
        pass

    return ret
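`_y_from_calc` (and `_set_atoms_y` below) negates ASE's stress sign and permutes ASE's Voigt order (xx, yy, zz, yz, xz, xy) into the (xx, yy, zz, xy, yz, zx) order that `atoms_to_graph` also produces from a 3x3 tensor. The permutation `[0, 1, 2, 5, 3, 4]` checked directly:

import numpy as np

ase_voigt = np.array(['xx', 'yy', 'zz', 'yz', 'xz', 'xy'])  # ASE's order
print(ase_voigt[[0, 1, 2, 5, 3, 4]].tolist())
# ['xx', 'yy', 'zz', 'xy', 'yz', 'xz']  (xz == zx for symmetric stress)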
def _set_atoms_y(
    atoms_list: List[ase.Atoms],
    energy_key: Optional[str] = None,
    force_key: Optional[str] = None,
    stress_key: Optional[str] = None,
) -> List[ase.Atoms]:
    """
    Define how SevenNet reads an ASE Atoms object for its y label.
    If energy_key, force_key, or stress_key is given, the corresponding
    label is obtained from the .info dict of the Atoms object. These values
    should be in eV, eV/Angstrom, and eV/Angstrom^3 for energy, force, and
    stress, respectively (stress in Voigt notation).

    Args:
        atoms_list (list[ase.Atoms]): target atoms to set y_labels
        energy_key (str, optional): key to get energy. Defaults to None.
        force_key (str, optional): key to get force. Defaults to None.
        stress_key (str, optional): key to get stress. Defaults to None.
    Returns:
        list[ase.Atoms]: list of ase.Atoms
    Raises:
        RuntimeError: if ase atoms are somewhat imperfect

    Use free_energy: atoms.get_potential_energy(force_consistent=True).
    If it is not available, use atoms.get_potential_energy().
    If stress is available, initialize the stress tensor.
    Ignore constraints like selective dynamics.
    """
    for atoms in atoms_list:
        from_calc = _y_from_calc(atoms)

        if energy_key is not None:
            atoms.info['y_energy'] = atoms.info.pop(energy_key)
        else:
            atoms.info['y_energy'] = from_calc['energy']

        if force_key is not None:
            atoms.arrays['y_force'] = atoms.arrays.pop(force_key)
        else:
            atoms.arrays['y_force'] = from_calc['force']

        if stress_key is not None:
            y_stress = -1 * atoms.info.pop(stress_key)
            atoms.info['y_stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]])
        else:
            atoms.info['y_stress'] = from_calc['stress']

    return atoms_list


def ase_reader(
    filename: str,
    energy_key: Optional[str] = None,
    force_key: Optional[str] = None,
    stress_key: Optional[str] = None,
    index: str = ':',
    **kwargs,
) -> List[ase.Atoms]:
    """
    Wrapper of ase.io.read
    """
    atoms_list = ase.io.read(filename, index=index, **kwargs)
    if not isinstance(atoms_list, list):
        atoms_list = [atoms_list]
    return _set_atoms_y(atoms_list, energy_key, force_key, stress_key)
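When reference labels live in `Atoms.info`/`Atoms.arrays` rather than in a calculator, the reader is pointed at them explicitly; the file and key names below are hypothetical and depend on how the file was written:

from sevenn.train.dataload import ase_reader

# hypothetical file and keys; units must be eV, eV/A, eV/A^3 (Voigt stress)
atoms_list = ase_reader(
    'labeled.extxyz',
    energy_key='dft_energy',   # taken from atoms.info
    force_key='dft_forces',    # taken from atoms.arrays
    stress_key='dft_stress',   # taken from atoms.info, sign-flipped
)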
# Reader
def structure_list_reader(filename: str, format_outputs: Optional[str] = None):
    """
    Read from structure_list using braceexpand and ASE

    Args:
        fname : filename of structure_list

    Returns:
        dictionary of lists of ASE structures.
        key is title of training data (user-defined)
    """
    parsers = DefaultParsersContainer(
        PositionsAndForces, Stress, Energy, Cell
    ).make_parsers()
    ocp = OutcarChunkParser(parsers=parsers)

    def parse_label(line):
        line = line.strip()
        if line.startswith('[') is False:
            return False
        elif line.endswith(']') is False:
            raise ValueError('wrong structure_list title format')
        return line[1:-1]

    def parse_fileline(line):
        line = line.strip().split()
        if len(line) == 1:
            line.append(':')
        elif len(line) != 2:
            raise ValueError('wrong structure_list format')
        return line[0], line[1]

    structure_list_file = open(filename, 'r')
    lines = structure_list_file.readlines()

    raw_str_dict = {}
    label = 'Default'
    for line in lines:
        if line.strip() == '':
            continue
        tmp_label = parse_label(line)
        if tmp_label:
            label = tmp_label
            raw_str_dict[label] = []
            continue
        elif label in raw_str_dict:
            files_expr, index_expr = parse_fileline(line)
            raw_str_dict[label].append((files_expr, index_expr))
        else:
            raise ValueError('wrong structure_list format')
    structure_list_file.close()

    structures_dict = {}
    info_dct = {'data_from': 'user_OUTCAR'}
    for title, file_lines in raw_str_dict.items():
        stct_lists = []
        for file_line in file_lines:
            files_expr, index_expr = file_line
            index = string2index(index_expr)
            for expanded_filename in list(braceexpand(files_expr)):
                if 'OUTCAR' in expanded_filename:
                    f_stream = open(expanded_filename, 'r')
                    # generator of all outcar ionic steps
                    gen_all = outcarchunks(f_stream, ocp)
                    try:
                        # TODO: index may not slice, it can be integer
                        it_atoms = islice(
                            gen_all, index.start, index.stop, index.step
                        )
                    except ValueError:
                        # TODO: support negative index
                        raise ValueError('Negative index is not supported yet')
                    info_dct_f = {
                        **info_dct,
                        'file': os.path.abspath(expanded_filename),
                    }
                    for idx, o in enumerate(it_atoms):
                        try:
                            istep = index.start + idx * index.step  # type: ignore
                            atoms = o.build()
                            atoms.info = {**info_dct_f, 'ionic_step': istep}.copy()
                        except TypeError:
                            # it is not slice of ionic steps
                            atoms = o.build()
                            atoms.info = info_dct_f.copy()
                        stct_lists.append(atoms)
                    f_stream.close()
                else:
                    stct_lists += ase.io.read(
                        expanded_filename,
                        index=index_expr,
                        parallel=False,
                    )
        structures_dict[title] = stct_lists

    return {k: _set_atoms_y(v) for k, v in structures_dict.items()}
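The file format this reader expects, sketched by writing one out: a `[title]` line opens a labeled section, and each following line is a brace-expandable path with an optional ASE index expression (defaulting to `:`). The paths below are hypothetical.

sample = """\
[bulk]
outcars/OUTCAR_{1..3} ::10
[surface]
surface/OUTCAR
"""
with open('structure_list', 'w') as f:
    f.write(sample)

# from sevenn.train.dataload import structure_list_reader
# structures = structure_list_reader('structure_list')  # needs the OUTCARs above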
def dict_reader(data_dict: Dict):
    data_dict_cp = copy.deepcopy(data_dict)
    ret = []
    file_list = data_dict_cp.pop('file_list', None)
    if file_list is None:
        raise KeyError('file_list is not found')

    data_weight_default = {
        'energy': 1.0,
        'force': 1.0,
        'stress': 1.0,
    }
    data_weight = data_weight_default.copy()
    data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {}))

    for file_dct in file_list:
        ftype = file_dct.pop('data_format', 'ase')
        files = list(braceexpand(file_dct.pop('file')))
        if ftype == 'ase':
            ret.extend(chain(*[ase_reader(f, **file_dct) for f in files]))
        elif ftype == 'graph':
            continue
        else:
            raise ValueError(f'{ftype} is not supported yet')

    for atoms in ret:
        atoms.info.update(data_dict_cp)
        atoms.info.update({KEY.DATA_WEIGHT: data_weight})

    return _set_atoms_y(ret)
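A sketch of the input `dict_reader` consumes: a `file_list` of per-file option dicts, optional per-quantity weights, and any remaining keys copied into `atoms.info` as metadata. The paths and extra key here are hypothetical.

import sevenn._keys as KEY

data_dict = {
    'file_list': [
        {'file': 'train_{a,b}.extxyz', 'data_format': 'ase'},
    ],
    KEY.DATA_WEIGHT: {'energy': 1.0, 'force': 0.5},  # stress defaults to 1.0
    'source': 'example',  # free-form metadata, ends up in atoms.info
}
# from sevenn.train.dataload import dict_reader
# atoms_list = dict_reader(data_dict)  # requires the files to exist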
def match_reader(reader_name: str, **kwargs):
    reader = None
    metadata = {}
    if reader_name == 'structure_list':
        reader = partial(structure_list_reader, **kwargs)
        metadata.update({'origin': 'structure_list'})
    else:
        reader = partial(ase_reader, **kwargs)
        metadata.update({'origin': 'ase_reader'})
    return reader, metadata


def file_to_dataset(
    file: str,
    cutoff: float,
    cores: int = 1,
    reader: Callable = ase_reader,
    label: Optional[str] = None,
    transfer_info: bool = True,
    use_weight: bool = False,
    use_modality: bool = False,
):
    """
    Deprecated
    Read file by reader > get list of atoms or dict of atoms
    """
    # expect label: atoms_list dct or atoms or list of atoms
    atoms = reader(file)

    if type(atoms) is list:
        if label is None:
            label = KEY.LABEL_NONE
        atoms_dct = {label: atoms}
    elif isinstance(atoms, ase.Atoms):
        if label is None:
            label = KEY.LABEL_NONE
        atoms_dct = {label: [atoms]}
    elif isinstance(atoms, dict):
        atoms_dct = atoms
    else:
        raise TypeError('The return of reader is not list or dict')

    graph_dct = {}
    for label, atoms_list in atoms_dct.items():
        graph_list = graph_build(
            atoms_list=atoms_list,
            cutoff=cutoff,
            num_cores=cores,
            transfer_info=transfer_info,
            y_from_calc=False,
        )
        label_info = label.split(':')
        for graph in graph_list:
            graph[KEY.USER_LABEL] = label_info[0].strip()
            if use_weight:
                find_weight = False
                for info in label_info[1:]:
                    if 'w=' in info.lower():
                        weights = info.split('=')[1]
                        try:
                            if ',' in weights:
                                weight_list = list(map(float, weights.split(',')))
                            else:
                                weight_list = [float(weights)] * 3
                            weight_dict = {}
                            for idx, loss_type in enumerate(LossType):
                                weight_dict[loss_type.value] = (
                                    weight_list[idx]
                                    if idx < len(weight_list)
                                    else 1
                                )
                            graph[KEY.DATA_WEIGHT] = weight_dict
                            find_weight = True
                            break
                        except:
                            raise ValueError(
                                'Weight must be a real number, but'
                                f' {weights} is given for {label}'
                            )
                if not find_weight:
                    weight_dict = {}
                    for loss_type in LossType:
                        weight_dict[loss_type.value] = 1
                    graph[KEY.DATA_WEIGHT] = weight_dict
            if use_modality:
                find_modality = False
                for info in label_info[1:]:
                    if 'm=' in info.lower():
                        graph[KEY.DATA_MODALITY] = (info.split('=')[1]).strip()
                        find_modality = True
                        break
                if not find_modality:
                    raise ValueError(f'Modality not given for {label}')
        graph_dct[label_info[0].strip()] = graph_list

    db = AtomGraphDataset(graph_dct, cutoff)
    return db
mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py
0 → 100644
import itertools
import random
from collections import Counter
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
from ase.data import chemical_symbols
from sklearn.linear_model import Ridge

import sevenn._keys as KEY
import sevenn.util as util


class AtomGraphDataset:
    """
    Deprecated
    class representing dataset of AtomGraphData
    the dataset is handled as dict, {label: data}
    if given data is List, it stores data as {KEY_DEFAULT: data}
    cutoff is kept as metadata of the graphs; it is not used for calculations.
    Every data point is expected to share one unique cutoff;
    no validity check of this condition is done inside the object

    attributes:
        dataset (Dict[str, List]): key is data label (str), value is list of data
        user_labels (List[str]): list of user labels, same as dataset.keys()
        meta (Dict, Optional): metadata of dataset
    for now, metadata 'might' have following keys:
        KEY.CUTOFF (float), KEY.CHEMICAL_SPECIES (Dict)
    """

    DATA_KEY_X = KEY.NODE_FEATURE  # atomic_number > one_hot_idx > one_hot_vector
    DATA_KEY_ENERGY = KEY.ENERGY
    DATA_KEY_FORCE = KEY.FORCE
    KEY_DEFAULT = KEY.LABEL_NONE

    def __init__(
        self,
        dataset: Union[Dict[str, List], List],
        cutoff: float,
        metadata: Optional[Dict] = None,
        x_is_one_hot_idx: bool = False,
    ):
        """
        Default constructor of AtomGraphDataset

        Args:
            dataset (Union[Dict[str, List], List]): dataset as dict or pure list
            metadata (Dict, Optional): metadata of data
            cutoff (float): cutoff radius of graphs inside the dataset
            x_is_one_hot_idx (bool): if True, x is one_hot_idx, else 'Z'

        'x' (node feature) of dataset can have 3 states: atomic_numbers,
        one_hot_idx, or one_hot_vector.
        atomic_numbers is general but cannot be used directly as input;
        one_hot_idx can be input of the model but requires 'type_map'
        """
        self.cutoff = cutoff
        self.x_is_one_hot_idx = x_is_one_hot_idx
        if metadata is None:
            metadata = {KEY.CUTOFF: cutoff}
        self.meta = metadata
        if type(dataset) is list:
            self.dataset = {self.KEY_DEFAULT: dataset}
        else:
            self.dataset = dataset
        self.user_labels = list(self.dataset.keys())
        # group_by_key here? or not?

    def rewrite_labels_to_data(self):
        """
        Based on self.dataset dict's keys,
        write data[KEY.USER_LABEL] to correspond to the dict's keys.
        Most of the time it is already correctly written,
        but a rewrite is required if someone rearranged the dataset their own way
        """
        for label, data_list in self.dataset.items():
            for data in data_list:
                data[KEY.USER_LABEL] = label

    def group_by_key(self, data_key: str = KEY.USER_LABEL):
        """
        group dataset list by given key, save it as dict,
        and change in-place

        Args:
            data_key (str): data key to group by
        original use is USER_LABEL, but it can be used for other keys
        if someone established it from data[KEY.INFO]
        """
        data_list = self.to_list()
        self.dataset = {}
        for datum in data_list:
            key = datum[data_key]
            if key not in self.dataset:
                self.dataset[key] = []
            self.dataset[key].append(datum)
        self.user_labels = list(self.dataset.keys())

    def separate_info(self, data_key: str = KEY.INFO):
        """
        Separate info from data and save it as list of dict
        to make it compatible with torch_geometric and later training
        """
        data_list = self.to_list()
        info_list = []
        for datum in data_list:
            if data_key not in datum:
                continue
            info_list.append(datum[data_key])
            del datum[data_key]  # It does change the self.dataset
            datum[data_key] = len(info_list) - 1
        self.info_list = info_list
        return (data_list, info_list)

    def get_species(self):
        """
        You can also use get_natoms and extract keys from there instead of this
        (and it is more efficient)
        get chemical species of dataset
        return list of SORTED chemical species (as str)
        """
        if hasattr(self, 'type_map'):
            natoms = self.get_natoms(self.type_map)
        else:
            natoms = self.get_natoms()
        species = set()
        for natom_dct in natoms.values():
            species.update(natom_dct.keys())
        species = sorted(list(species))
        return species

    def get_modalities(self):
        modalities = set()
        for data_list in self.dataset.values():
            datum = data_list[0].to_dict()
            if KEY.DATA_MODALITY in datum.keys():
                modalities.add(datum[KEY.DATA_MODALITY])
            else:
                return []
        return list(modalities)

    def write_modal_attr(
        self, modal_type_mapper: dict, write_modal_type: bool = False
    ):
        num_modalities = len(modal_type_mapper)
        for data_list in self.dataset.values():
            for data in data_list:
                tmp_tensor = torch.zeros(num_modalities)
                if data[KEY.DATA_MODALITY] != 'common':
                    modal_idx = modal_type_mapper[data[KEY.DATA_MODALITY]]
                    tmp_tensor[modal_idx] = 1.0
                    if write_modal_type:
                        data[KEY.MODAL_TYPE] = modal_idx
                data[KEY.MODAL_ATTR] = tmp_tensor

    def get_dict_sort_by_modality(self):
        dict_sort_by_modality = {}
        for data_list in self.dataset.values():
            try:
                modal_key = data_list[0].to_dict()[KEY.DATA_MODALITY]
            except:  # Dataset is not modal
                raise ValueError('This dataset has no modality.')
            if modal_key not in dict_sort_by_modality.keys():
                dict_sort_by_modality[modal_key] = []
            dict_sort_by_modality[modal_key].extend(data_list)
        return dict_sort_by_modality

    def len(self):
        if (
            len(self.dataset.keys()) == 1
            and list(self.dataset.keys())[0] == AtomGraphDataset.KEY_DEFAULT
        ):
            return len(self.dataset[AtomGraphDataset.KEY_DEFAULT])
        else:
            return {k: len(v) for k, v in self.dataset.items()}

    def get(self, idx: int, key: Optional[str] = None):
        if key is None:
            key = self.KEY_DEFAULT
        return self.dataset[key][idx]

    def items(self):
        return self.dataset.items()

    def to_dict(self):
        dct_dataset = {}
        for label, data_list in self.dataset.items():
            dct_dataset[label] = [datum.to_dict() for datum in data_list]
        self.dataset = dct_dataset
        return self

    def x_to_one_hot_idx(self, type_map: Dict[int, int]):
        """
        type_map is dict of {atomic_number: one_hot_idx}
        after this process, the dataset has dependency on type_map
        or the chemical species the user wants to consider
        """
        assert self.x_is_one_hot_idx is False
        for data_list in self.dataset.values():
            for datum in data_list:
                datum[self.DATA_KEY_X] = torch.LongTensor(
                    [type_map[z.item()] for z in datum[self.DATA_KEY_X]]
                )
        self.type_map = type_map
        self.x_is_one_hot_idx = True

    def toggle_requires_grad_of_data(self, key: str, requires_grad_value: bool):
        """
        set requires_grad of specific key of data (pos, edge_vec, ...)
        """
        for data_list in self.dataset.values():
            for datum in data_list:
                datum[key].requires_grad_(requires_grad_value)

    def divide_dataset(
        self,
        ratio: float,
        constant_ratio_btw_labels: bool = True,
        ignore_test: bool = True,
    ):
        """
        divide dataset into 1-2*ratio : ratio : ratio
        return divided AtomGraphDataset
        returned value lost its dict key and became {KEY_DEFAULT: datalist},
        but KEY.USER_LABEL of each data is preserved
        """

        def divide(ratio: float, data_list: List, ignore_test=True):
            if ratio > 0.5:
                raise ValueError('Ratio must not exceed 0.5')
            data_len = len(data_list)
            random.shuffle(data_list)
            n_validation = int(data_len * ratio)
            if n_validation == 0:
                raise ValueError('# of validation set is 0, increase your dataset')

            if ignore_test:
                test_list = []
                n_train = data_len - n_validation
                train_list = data_list[0:n_train]
                valid_list = data_list[n_train:]
            else:
                n_train = data_len - 2 * n_validation
                train_list = data_list[0:n_train]
                valid_list = data_list[n_train : n_train + n_validation]
                test_list = data_list[n_train + n_validation : data_len]
            return train_list, valid_list, test_list

        lists = ([], [], [])  # train, valid, test
        if constant_ratio_btw_labels:
            for data_list in self.dataset.values():
                for store, divided in zip(lists, divide(ratio, data_list)):
                    store.extend(divided)
        else:
            lists = divide(ratio, self.to_list())

        dbs = tuple(
            AtomGraphDataset(data, self.cutoff, self.meta) for data in lists
        )
        for db in dbs:
            db.group_by_key()
        return dbs

    def to_list(self):
        return list(itertools.chain(*self.dataset.values()))

    def get_natoms(self, type_map: Optional[Dict[int, int]] = None):
        """
        if x_is_one_hot_idx, type_map is required
        type_map: Z->one_hot_index(node_feature)
        return Dict{label: {symbol: natom}}
        """
        assert not (self.x_is_one_hot_idx is True and type_map is None)
        natoms = {}
        for label, data in self.dataset.items():
            natoms[label] = Counter()
            for datum in data:
                if self.x_is_one_hot_idx and type_map is not None:
                    Zs = util.onehot_to_chem(datum[self.DATA_KEY_X], type_map)
                else:
                    Zs = [
                        chemical_symbols[z]
                        for z in datum[self.DATA_KEY_X].tolist()
                    ]
                cnt = Counter(Zs)
                natoms[label] += cnt
            natoms[label] = dict(natoms[label])
        return natoms

    def get_per_atom_mean(self, key: str, key_num_atoms: str = KEY.NUM_ATOMS):
        """
        return per_atom mean of given data key
        """
        eng_list = torch.Tensor(
            [x[key] / x[key_num_atoms] for x in self.to_list()]
        )
        return float(torch.mean(eng_list))

    def get_per_atom_energy_mean(self):
        """
        alias for get_per_atom_mean(KEY.ENERGY)
        """
        return self.get_per_atom_mean(self.DATA_KEY_ENERGY)

    def get_species_ref_energy_by_linear_comb(self, num_chem_species: int):
        """
        Total energy as y, composition as c_i,
        solve linear regression of y = c_i*X
        sklearn Ridge as solver
        x should be one-hot-indexed
        give num_chem_species if possible
        """
        assert self.x_is_one_hot_idx is True
        data_list = self.to_list()
        c = torch.zeros((len(data_list), num_chem_species))
        for idx, datum in enumerate(data_list):
            c[idx] = torch.bincount(
                datum[self.DATA_KEY_X], minlength=num_chem_species
            )
        y = torch.Tensor([x[self.DATA_KEY_ENERGY] for x in data_list])
        c = c.numpy()
        y = y.numpy()

        # tweak to fine tune training from many-element to small element
        zero_indices = np.all(c == 0, axis=0)
        c_reduced = c[:, ~zero_indices]
        full_coeff = np.zeros(num_chem_species)
        coef_reduced = (
            Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_
        )
        full_coeff[~zero_indices] = coef_reduced
        return full_coeff

    def get_force_rms(self):
        force_list = []
        for x in self.to_list():
            force_list.extend(
                x[self.DATA_KEY_FORCE]
                .reshape(
                    -1,
                )
                .tolist()
            )
        force_list = torch.Tensor(force_list)
        return float(torch.sqrt(torch.mean(torch.pow(force_list, 2))))

    def get_species_wise_force_rms(self, num_chem_species: int):
        """
        Return force rms for each species
        Averaged over each component (x, y, z)
        """
        assert self.x_is_one_hot_idx is True
        data_list = self.to_list()
        atomx = torch.concat([d[self.DATA_KEY_X] for d in data_list])
        force = torch.concat([d[self.DATA_KEY_FORCE] for d in data_list])
        index = atomx.repeat_interleave(3, 0).reshape(force.shape)
        rms = torch.zeros(
            (num_chem_species, 3), dtype=force.dtype, device=force.device
        )
        rms.scatter_reduce_(
            0, index, force.square(), reduce='mean', include_self=False
        )
        return torch.sqrt(rms.mean(dim=1))

    def get_avg_num_neigh(self):
        n_neigh = []
        for _, data_list in self.dataset.items():
            for data in data_list:
                n_neigh.extend(
                    np.unique(data[KEY.EDGE_IDX][0], return_counts=True)[1]
                )
        avg_num_neigh = np.average(n_neigh)
        return avg_num_neigh

    def get_statistics(self, key: str):
        """
        return dict of statistics of given key (energy, force, stress)
        key of dict is its label, and 'Total' for total statistics
        value of dict is dict of statistics (mean, std, median, max, min)
        """

        def _get_statistic_dict(tensor_list):
            data_list = torch.cat(
                [
                    tensor.reshape(
                        -1,
                    )
                    for tensor in tensor_list
                ]
            )
            data_list = data_list[~torch.isnan(data_list)]
            return {
                'mean': float(torch.mean(data_list)),
                'std': float(torch.std(data_list)),
                'median': float(torch.median(data_list)),
                'max': (
                    torch.nan
                    if data_list.numel() == 0
                    else float(torch.max(data_list))
                ),
                'min': (
                    torch.nan
                    if data_list.numel() == 0
                    else float(torch.min(data_list))
                ),
            }

        res = {}
        for label, values in self.dataset.items():
            # flatten list of torch.Tensor (values)
            tensor_list = [x[key] for x in values]
            res[label] = _get_statistic_dict(tensor_list)

        tensor_list = [x[key] for x in self.to_list()]
        res['Total'] = _get_statistic_dict(tensor_list)
        return res

    def augment(self, dataset, validator: Optional[Callable] = None):
        """check meta compatibility here
        dataset(AtomGraphDataset): data to augment
        validator(Callable, Optional): function(self, dataset) -> bool
        if validator is None, by default it checks
        whether cutoff & chemical_species are the same before augmenting
        check consistent data type, float, double, long integer etc
        """

        def default_validator(db1, db2):
            cut_consis = db1.cutoff == db2.cutoff
            # compare unordered lists
            x_is_not_onehot = (not db1.x_is_one_hot_idx) and (
                not db2.x_is_one_hot_idx
            )
            return cut_consis and x_is_not_onehot

        if validator is None:
            validator = default_validator
        if not validator(self, dataset):
            raise ValueError('given datasets are not compatible, check cutoffs')
        for key, val in dataset.items():
            if key in self.dataset:
                self.dataset[key].extend(val)
            else:
                self.dataset.update({key: val})
        self.user_labels = list(self.dataset.keys())

    def unify_dtypes(
        self,
        float_dtype: torch.dtype = torch.float32,
        int_dtype: torch.dtype = torch.int64,
    ):
        data_list = self.to_list()
        for datum in data_list:
            for k, v in list(datum.items()):
                datum[k] = util.dtype_correct(v, float_dtype, int_dtype)

    def delete_data_key(self, key: str):
        for data in self.to_list():
            del data[key]

    # TODO: this by_label is not straightforward
    def save(self, path: str, by_label: bool = False):
        if by_label:
            for label, data in self.dataset.items():
                torch.save(
                    AtomGraphDataset(
                        {label: data}, self.cutoff, metadata=self.meta
                    ),
                    f'{path}/{label}.sevenn_data',
                )
        else:
            if path.endswith('.sevenn_data') is False:
                path += '.sevenn_data'
            torch.save(self, path)
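The reference-energy fit used by `get_species_ref_energy_by_linear_comb` (and its graph-dataset counterparts) is ordinary ridge regression of total energy against composition counts. A toy check that exactly linear data is recovered:

import numpy as np
from sklearn.linear_model import Ridge

true_ref = np.array([-3.0, -5.0])        # per-element reference energies
c = np.array([[4, 0], [2, 2], [0, 4]])   # composition counts per structure
y = c @ true_ref                         # total energies, exactly linear

coef = Ridge(alpha=0.1, fit_intercept=False).fit(c, y).coef_
print(coef)  # close to [-3., -5.]; alpha=0.1 biases it slightly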
mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py
0 → 100644
import os
import warnings
from collections import Counter
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.serialization
import torch.utils.data
import yaml
from ase.data import chemical_symbols
from torch_geometric.data import Data
from torch_geometric.data.in_memory_dataset import InMemoryDataset
from tqdm import tqdm

import sevenn._keys as KEY
import sevenn.train.dataload as dataload
import sevenn.util as util
from sevenn import __version__
from sevenn._const import NUM_UNIV_ELEMENT
from sevenn.atom_graph_data import AtomGraphData
from sevenn.logger import Logger

if torch.__version__.split()[0] >= '2.4.0':
    # load graph without error
    torch.serialization.add_safe_globals([AtomGraphData])
    # warning from PyG, for later torch versions
    warnings.filterwarnings(
        'ignore',
        message='You are using `torch.load` with `weights_only=False`',
    )


def _tag_graphs(graph_list: List[AtomGraphData], tag: str):
    """
    WIP: To be used
    """
    for g in graph_list:
        g[KEY.TAG] = tag
    return graph_list


def pt_to_args(pt_filename: str):
    """
    Return arg dict of root and processed_name from path to .pt

    Usage:
        dataset = SevenNetGraphDataset(
            **pt_to_args({path}/sevenn_data/dataset.pt)
        )
    """
    processed_dir, basename = os.path.split(pt_filename)
    return {
        'root': os.path.dirname(processed_dir),
        'processed_name': os.path.basename(basename),
    }


def _run_stat(
    graph_list,
    y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS],
) -> Dict[str, Any]:
    """
    Loop over dataset and init any statistics we might need
    """
    n_neigh = []
    natoms_counter = Counter()
    composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT))
    stats: Dict[str, Any] = {y: {'_array': []} for y in y_keys}

    for i, graph in tqdm(
        enumerate(graph_list), desc='run_stat', total=len(graph_list)
    ):
        z_tensor = graph[KEY.ATOMIC_NUMBERS]
        natoms_counter.update(z_tensor.tolist())
        composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT)
        n_neigh.append(
            torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1]
        )
        for y, dct in stats.items():
            dct['_array'].append(
                graph[y].reshape(
                    -1,
                )
            )

    stats.update({'num_neighbor': {'_array': n_neigh}})

    for y, dct in stats.items():
        array = torch.cat(dct['_array'])
        if array.dtype == torch.int64:  # because of n_neigh
            array = array.to(torch.float)
        try:
            median = torch.quantile(array, q=0.5)
        except RuntimeError:
            warnings.warn(f'skip median due to too large tensor size: {y}')
            median = torch.nan
        dct.update(
            {
                'mean': float(torch.mean(array)),
                'std': float(torch.std(array, correction=0)),
                'median': float(median),
                'max': float(torch.max(array)),
                'min': float(torch.min(array)),
                'count': array.numel(),
                '_array': array,
            }
        )

    natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()}
    natoms['total'] = sum(list(natoms.values()))
    stats.update({'_composition': composition, 'natoms': natoms})
    return stats


def _elemwise_reference_energies(composition: np.ndarray, energies: np.ndarray):
    from sklearn.linear_model import Ridge

    c = composition
    y = energies
    zero_indices = np.all(c == 0, axis=0)
    c_reduced = c[:, ~zero_indices]
    # will not 100% reproduce, as it is sorted by Z
    # train/dataset.py was sorted by alphabets of chemical species
    coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_
    full_coeff = np.zeros(NUM_UNIV_ELEMENT)
    full_coeff[~zero_indices] = coef_reduced
    return full_coeff.tolist()  # ex: full_coeff[1] = H_reference_energy
class
SevenNetGraphDataset
(
InMemoryDataset
):
"""
Replacement of AtomGraphDataset. (and .sevenn_data)
Extends InMemoryDataset of PyG. From given 'files', and 'cutoff',
build graphs for training SevenNet model. Preprocessed graphs are saved to
f'{root}/sevenn_data/{processed_name}.pt
TODO: Save meta info (cutoff) by overriding .save and .load
TODO: 'tag' is not used yet, but initialized
'tag' is replacement for 'label', and each datapoint has it as integer
'tag' is usually parsed from if the structure_list of load_dataset
Args:
root: path to save/load processed PyG dataset
cutoff: edge cutoff of given AtomGraphData
files: list of filenames or dict describing how to parse the file
ASE readable (with proper extension), structure_list, .sevenn_data,
dict containing file_list (see dict_reader of train/dataload.py)
process_num_cores: # of cpu cores to build graph
processed_name: save as {root}/sevenn_data/{processed_name}.pt
pre_transfrom: optional transform for each graph: def (graph) -> graph
pre_filter: optional filtering function for each graph: def (graph) -> graph
force_reload: if True, reload dataset from files even if there exist
{root}/sevenn_data/{processed_name}
**process_kwargs: keyword arguments that will be passed into ase.io.read
"""
def
__init__
(
self
,
cutoff
:
float
,
root
:
Optional
[
str
]
=
None
,
files
:
Optional
[
Union
[
str
,
List
[
Any
]]]
=
None
,
process_num_cores
:
int
=
1
,
processed_name
:
str
=
'graph.pt'
,
transform
:
Optional
[
Callable
]
=
None
,
pre_transform
:
Optional
[
Callable
]
=
None
,
pre_filter
:
Optional
[
Callable
]
=
None
,
use_data_weight
:
bool
=
False
,
log
:
bool
=
True
,
force_reload
:
bool
=
False
,
drop_info
:
bool
=
True
,
**
process_kwargs
,
):
self
.
cutoff
=
cutoff
if
files
is
None
:
files
=
[]
elif
isinstance
(
files
,
str
):
files
=
[
files
]
# user convenience
_files
=
[]
for
f
in
files
:
if
isinstance
(
f
,
str
):
f
=
os
.
path
.
abspath
(
f
)
_files
.
append
(
f
)
self
.
_files
=
_files
self
.
_full_file_list
=
[]
if
not
processed_name
.
endswith
(
'.pt'
):
processed_name
+=
'.pt'
self
.
_processed_names
=
[
processed_name
,
# {root}/sevenn_data/{name}.pt
processed_name
.
replace
(
'.pt'
,
'.yaml'
),
]
root
=
root
or
'./'
_pdir
=
os
.
path
.
join
(
root
,
'sevenn_data'
)
_pt
=
os
.
path
.
join
(
_pdir
,
self
.
_processed_names
[
0
])
if
not
os
.
path
.
exists
(
_pt
)
and
len
(
self
.
_files
)
==
0
:
raise
ValueError
(
(
f
'
{
_pt
}
not found and no files to process. '
+
'If you copied only .pt file, please copy '
+
'whole sevenn_data dir without changing its name.'
+
' They all work together.'
)
)
_yam
=
os
.
path
.
join
(
_pdir
,
self
.
_processed_names
[
1
])
if
not
os
.
path
.
exists
(
_yam
)
and
len
(
self
.
_files
)
==
0
:
raise
ValueError
(
f
'
{
_yam
}
not found and no files to process'
)
self
.
process_num_cores
=
process_num_cores
self
.
process_kwargs
=
process_kwargs
self
.
use_data_weight
=
use_data_weight
self
.
drop_info
=
drop_info
self
.
tag_map
=
{}
self
.
statistics
=
{}
self
.
finalized
=
False
super
().
__init__
(
root
,
transform
,
pre_transform
,
pre_filter
,
log
=
log
,
force_reload
=
force_reload
,
)
# Internally calls 'process'
self
.
load
(
self
.
processed_paths
[
0
])
# load pt, saved after process
def
load
(
self
,
path
:
str
,
data_cls
=
Data
)
->
None
:
super
().
load
(
path
,
data_cls
)
if
len
(
self
)
==
0
:
warnings
.
warn
(
f
'No graphs found
{
self
.
processed_paths
[
0
]
}
'
)
if
len
(
self
.
statistics
)
==
0
:
# dataset is loaded from existing pt file.
self
.
_load_meta
()

    def _load_meta(self) -> None:
        with open(self.processed_paths[1], 'r') as f:
            meta = yaml.safe_load(f)

        if meta['sevennet_version'] == '0.10.0':
            self._save_meta(list(self))
            with open(self.processed_paths[1], 'r') as f:
                meta = yaml.safe_load(f)

        cutoff = float(meta['cutoff'])
        if float(meta['cutoff']) != self.cutoff:
            warnings.warn(
                'Loaded dataset is built with different cutoff length: '
                + f'{cutoff} != {self.cutoff}, dataset cutoff will be'
                + f' overwritten to {cutoff}'
            )
        self.cutoff = cutoff
        self._files = meta['files']
        self.statistics = meta['statistics']

    def __getitem__(self, idx):
        graph = super().__getitem__(idx)
        if self.drop_info:
            graph.pop(KEY.INFO, None)  # type: ignore
        return graph

    @property
    def raw_file_names(self) -> List[Any]:
        return self._files

    @property
    def processed_file_names(self) -> List[str]:
        return self._processed_names

    @property
    def processed_dir(self) -> str:
        return os.path.join(self.root, 'sevenn_data')

    @property
    def full_file_list(self) -> Union[List[str], None]:
        return self._full_file_list

    def process(self):
        graph_list: List[AtomGraphData] = []
        for file in self.raw_file_names:
            tmplist = SevenNetGraphDataset.file_to_graph_list(
                file=file,
                cutoff=self.cutoff,
                num_cores=self.process_num_cores,
                **self.process_kwargs,
            )
            if isinstance(file, str) and self._full_file_list is not None:
                self._full_file_list.extend([os.path.abspath(file)] * len(tmplist))
            else:
                self._full_file_list = None
            graph_list.extend(tmplist)

        processed_graph_list = []
        for data in graph_list:
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            if self.use_data_weight:
                # pop data weight from info, and assign to graph
                weight = data[KEY.INFO].pop(
                    KEY.DATA_WEIGHT,
                    {'energy': 1.0, 'force': 1.0, 'stress': 1.0},
                )
                data[KEY.DATA_WEIGHT] = weight
            processed_graph_list.append(data)

        if len(processed_graph_list) == 0:
            # Cannot save at all if there is no graph (error in PyG), raise an error
            raise ValueError('Zero graph found after filtering')

        # save graphs, handled by torch_geometric
        self.save(processed_graph_list, self.processed_paths[0])
        self._save_meta(processed_graph_list)
        if self.log:
            Logger().writeline(f'Dataset is saved: {self.processed_paths[0]}')

    def _save_meta(self, graph_list) -> None:
        stats = _run_stat(graph_list)
        stats['elemwise_reference_energies'] = _elemwise_reference_energies(
            stats['_composition'].numpy(),
            stats[KEY.ENERGY]['_array'].numpy(),
        )
        self.statistics = stats

        stats_save = {}
        for label, dct in self.statistics.items():
            if label.startswith('_'):
                continue
            stats_save[label] = {}
            if not isinstance(dct, dict):
                stats_save[label] = dct
            else:
                for k, v in dct.items():
                    if k.startswith('_'):
                        continue
                    stats_save[label][k] = v

        meta = {
            'sevennet_version': __version__,
            'cutoff': self.cutoff,
            'when': datetime.now().strftime('%Y-%m-%d %H:%M'),
            'files': self._files,
            'statistics': stats_save,
            'species': self.species,
            'num_graphs': self.statistics[KEY.ENERGY]['count'],
            'per_atom_energy_mean': self.per_atom_energy_mean,
            'force_rms': self.force_rms,
            'per_atom_energy_std': self.per_atom_energy_std,
            'avg_num_neigh': self.avg_num_neigh,
            'sqrt_avg_num_neigh': self.sqrt_avg_num_neigh,
        }

        with open(self.processed_paths[1], 'w') as f:
            yaml.dump(meta, f, default_flow_style=False)

    @property
    def species(self):
        return [z for z in self.statistics['natoms'].keys() if z != 'total']

    @property
    def natoms(self):
        return self.statistics['natoms']

    @property
    def per_atom_energy_mean(self):
        return self.statistics[KEY.PER_ATOM_ENERGY]['mean']

    @property
    def elemwise_reference_energies(self):
        return self.statistics['elemwise_reference_energies']

    @property
    def force_rms(self):
        mean = self.statistics[KEY.FORCE]['mean']
        std = self.statistics[KEY.FORCE]['std']
        return float((mean**2 + std**2) ** (0.5))

    @property
    def per_atom_energy_std(self):
        return self.statistics['per_atom_energy']['std']

    @property
    def avg_num_neigh(self):
        return self.statistics['num_neighbor']['mean']

    @property
    def sqrt_avg_num_neigh(self):
        return self.avg_num_neigh**0.5

    @staticmethod
    def _read_sevenn_data(filename: str) -> Tuple[List[AtomGraphData], float]:
        # backward compatibility
        from sevenn.train.dataset import AtomGraphDataset

        dataset = torch.load(filename, map_location='cpu', weights_only=False)
        if isinstance(dataset, AtomGraphDataset):
            graph_list = []
            for _, graphs in dataset.dataset.items():  # type: ignore
                # TODO: transfer label to tag (who is going to need this?)
                graph_list.extend(graphs)
            return graph_list, dataset.cutoff
        else:
            raise ValueError(f'Not sevenn_data type: {type(dataset)}')

    @staticmethod
    def _read_structure_list(
        filename: str, cutoff: float, num_cores: int = 1
    ) -> List[AtomGraphData]:
        datadct = dataload.structure_list_reader(filename)
        graph_list = []
        for tag, atoms_list in datadct.items():
            tmp = dataload.graph_build(atoms_list, cutoff, num_cores)
            graph_list.extend(_tag_graphs(tmp, tag))
        return graph_list

    @staticmethod
    def _read_ase_readable(
        filename: str,
        cutoff: float,
        num_cores: int = 1,
        tag: str = '',
        transfer_info: bool = True,
        allow_unlabeled: bool = False,
        **ase_kwargs,
    ) -> List[AtomGraphData]:
        pbc_override = ase_kwargs.pop('pbc', None)
        atoms_list = dataload.ase_reader(filename, **ase_kwargs)
        for atoms in atoms_list:
            if pbc_override is not None:
                atoms.pbc = pbc_override
        graph_list = dataload.graph_build(
            atoms_list,
            cutoff,
            num_cores,
            transfer_info=transfer_info,
            allow_unlabeled=allow_unlabeled,
        )
        if tag != '':
            graph_list = _tag_graphs(graph_list, tag)
        return graph_list

    @staticmethod
    def _read_graph_dataset(
        filename: str, cutoff: float, **kwargs
    ) -> List[AtomGraphData]:
        meta_f = filename.replace('.pt', '.yaml')
        orig_cutoff = cutoff
        if not os.path.exists(filename):
            raise FileNotFoundError(f'No such file: {filename}')
        if not os.path.exists(meta_f):
            warnings.warn('No meta info found, beware of cutoff...')
        else:
            with open(meta_f, 'r') as f:
                meta = yaml.safe_load(f)
            orig_cutoff = float(meta['cutoff'])
            if orig_cutoff != cutoff:
                warnings.warn(
                    f'{filename} has different cutoff length: '
                    + f'{cutoff} != {orig_cutoff}'
                )
        ds_args: dict[str, Any] = dict({'cutoff': orig_cutoff})
        ds_args.update(pt_to_args(filename))
        ds_args.update(kwargs)
        dataset = SevenNetGraphDataset(**ds_args)
        # TODO: hard coded. consult with inference.py
        glist = [g.fit_dimension() for g in dataset]  # type: ignore
        for g in glist:
            if KEY.STRESS in g:
                # (1, 6) is what we want
                g[KEY.STRESS] = g[KEY.STRESS].unsqueeze(0)
        return glist

    @staticmethod
    def _read_dict(
        data_dict: dict,
        cutoff: float,
        num_cores: int = 1,
    ):
        # logic same as the dataload dict_reader, but handles graphs
        data_dict_cp = deepcopy(data_dict)
        file_list = data_dict_cp.get('file_list', None)
        if file_list is None:
            raise KeyError('file_list is not found')
        data_weight_default = {
            'energy': 1.0,
            'force': 1.0,
            'stress': 1.0,
        }
        data_weight = data_weight_default.copy()
        data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {}))

        graph_list = []
        for file_dct in file_list:
            ftype = file_dct.pop('data_format', 'ase')
            if ftype != 'graph':
                continue
            graph_list.extend(
                SevenNetGraphDataset._read_graph_dataset(
                    file_dct.get('file'), cutoff=cutoff
                )
            )
        for graph in graph_list:
            if KEY.INFO not in graph:
                graph[KEY.INFO] = {}
            graph[KEY.INFO].update(data_dict_cp)
            graph[KEY.INFO].update({KEY.DATA_WEIGHT: data_weight})

        atoms_list = dataload.dict_reader(data_dict)
        graph_list.extend(dataload.graph_build(atoms_list, cutoff, num_cores))
        return graph_list

    @staticmethod
    def file_to_graph_list(
        file: Union[str, dict], cutoff: float, num_cores: int = 1, **kwargs
    ) -> List[AtomGraphData]:
        """
        kwargs: if file is ASE readable, passed to ase.io.read
        """
        if isinstance(file, str) and not os.path.isfile(file):
            raise ValueError(f'No such file: {file}')
        graph_list: List[AtomGraphData]
        if isinstance(file, dict):
            graph_list = SevenNetGraphDataset._read_dict(
                file, cutoff, num_cores, **kwargs
            )
        elif file.endswith('.pt'):
            graph_list = SevenNetGraphDataset._read_graph_dataset(file, cutoff)
        elif file.endswith('.sevenn_data'):
            graph_list, cutoff_other = SevenNetGraphDataset._read_sevenn_data(file)
            if cutoff_other != cutoff:
                warnings.warn(f'Given {file} has different cutoff: {cutoff_other}!')
            cutoff = cutoff_other
        elif 'structure_list' in file:
            graph_list = SevenNetGraphDataset._read_structure_list(
                file, cutoff, num_cores
            )
        else:
            graph_list = SevenNetGraphDataset._read_ase_readable(
                file, cutoff, num_cores, **kwargs
            )
        return graph_list


def from_single_path(
    path: Union[str, List],
    override_data_weight: bool = True,
    **dataset_kwargs,
) -> Union[SevenNetGraphDataset, None]:
    """
    Convenience routine for loading a single .pt dataset.
    If given a dict that has data_weight, apply it using a transform.
    """
    data_weight = {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
    spath = _extract_single_path(path)
    if spath is None:
        return None

    if isinstance(spath, str):
        if not spath.endswith('.pt'):
            return None
        dataset_kwargs.update(pt_to_args(spath))
    elif isinstance(spath, dict):
        file = _extract_file_from_dict(spath)
        if file is None or not file.endswith('.pt'):
            return None
        dataset_kwargs.update(pt_to_args(file))
        data_weight_user = spath.get(KEY.DATA_WEIGHT, None)
        if data_weight_user is not None:
            data_weight.update(data_weight_user)
    else:
        return None

    if override_data_weight:
        dataset_kwargs['transform'] = _chain_data_weight_override(
            dataset_kwargs.get('transform'), data_weight
        )
    return SevenNetGraphDataset(**dataset_kwargs)


def _extract_single_path(path: Union[str, List]) -> Union[str, dict, None]:
    """Extracts a single path from the input,
    ensuring it is either a single string or a list with one item."""
    if isinstance(path, list):
        return path[0] if len(path) == 1 else None
    return path if isinstance(path, (str, dict)) else None


def _extract_file_from_dict(path_dict: dict) -> Union[str, None]:
    """Extracts a single file path from the dictionary, ensuring it is valid."""
    file_list = path_dict.get('file_list', None)
    if file_list and len(file_list) == 1:
        file = file_list[0].get('file', None)
        return file if isinstance(file, str) else None
    return None


def _chain_data_weight_override(transform_func, data_weight):
    """Creates a transform function that overrides the data weight."""

    def chained_transform(graph):
        graph = transform_func(graph) if transform_func is not None else graph
        graph[KEY.INFO].pop(KEY.DATA_WEIGHT, None)
        graph[KEY.DATA_WEIGHT] = data_weight
        return graph

    return chained_transform
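

# Behavior sketch for the helper above (a plain dict stands in for
# AtomGraphData; that substitution is an assumption for illustration only):
#
#   override = _chain_data_weight_override(
#       None, {'energy': 2.0, 'force': 1.0, 'stress': 1.0}
#   )
#   g = {KEY.INFO: {KEY.DATA_WEIGHT: {'energy': 1.0}}}
#   g = override(g)
#   # g[KEY.DATA_WEIGHT] is now the override dict; the stale copy under
#   # KEY.INFO was popped so it cannot shadow the new weights.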


# script, return dict of SevenNetGraphDataset
def from_config(
    config: Dict[str, Any],
    working_dir: str = os.getcwd(),
    dataset_keys: Optional[List[str]] = None,
):
    log = Logger()
    if dataset_keys is None:
        dataset_keys = []
        for k in config:
            if k.startswith('load_') and k.endswith('_path'):
                dataset_keys.append(k)

    if KEY.LOAD_TRAINSET not in dataset_keys:
        raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config')

    # initialize arguments for loading dataset
    dataset_args = {
        'cutoff': config[KEY.CUTOFF],
        'root': working_dir,
        'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1),
        'use_data_weight': config.get(KEY.USE_WEIGHT, False),
        **config.get(KEY.DATA_FORMAT_ARGS, {}),
    }

    datasets = {}
    for dk in dataset_keys:
        if not (paths := config[dk]):
            continue
        if isinstance(paths, str):
            paths = [paths]
        name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]])
        if (dataset := from_single_path(paths, **dataset_args)) is not None:
            datasets[name] = dataset
        else:
            dataset_args.update({'files': paths, 'processed_name': name})
            dataset_path = os.path.join(working_dir, 'sevenn_data', f'{name}.pt')
            if os.path.exists(dataset_path) and 'force_reload' not in dataset_args:
                log.writeline(
                    f'Dataset will be loaded from {dataset_path}, without update. '
                    + 'If you have changed your files to read, put force_reload=True'
                    + ' under the data_format_args key'
                )
            datasets[name] = SevenNetGraphDataset(**dataset_args)

    train_set = datasets['trainset']

    chem_species = set(train_set.species)
    # print statistics of each dataset
    for name, dataset in datasets.items():
        log.bar()
        log.writeline(f'{name} distribution:')
        log.statistic_write(dataset.statistics)
        log.format_k_v('# structures (graph)', len(dataset), write=True)
        chem_species.update(dataset.species)
    log.bar()

    # initialize known species from dataset if 'auto'
    # sorted to alphabetical order (which is same as before)
    chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
    if all([config[ck] == 'auto' for ck in chem_keys]):  # see parse_input.py
        log.writeline('Known species are obtained from the dataset')
        config.update(util.chemical_species_preprocess(sorted(list(chem_species))))

    # retrieve shift, scale, conv_denominators from user input (keyword)
    init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
    for k in init_from_stats:
        input = config[k]  # statistic key or numbers
        # If it is not 'str', 1: It is 'continue' training
        #                     2: User manually inserted numbers
        if isinstance(input, str) and hasattr(train_set, input):
            var = getattr(train_set, input)
            config.update({k: var})
            log.writeline(f'{k} is obtained from statistics')
        elif isinstance(input, str) and not hasattr(train_set, input):
            raise NotImplementedError(input)

    if 'validset' not in datasets and config.get(KEY.RATIO, 0.0) > 0.0:
        log.writeline('Use validation set as random split from the training set')
        log.writeline(
            'Note that statistics, shift, scale, and conv_denominator are '
            + 'computed before random split.\nIf you want these after random '
            + 'split, please preprocess dataset and set it as load_trainset_path '
            + 'and load_validset_path explicitly.'
        )
        ratio = float(config[KEY.RATIO])
        train, valid = torch.utils.data.random_split(
            datasets['trainset'], (1.0 - ratio, ratio)
        )
        datasets['trainset'] = train
        datasets['validset'] = valid

    return datasets
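
A minimal usage sketch of the dataset class above; the file name 'train.extxyz' and the cutoff value are placeholders, not part of the module:

from sevenn.train.graph_dataset import SevenNetGraphDataset

# Graphs are built once and cached under ./sevenn_data/graph.pt;
# later constructions reload the cache unless force_reload=True.
dataset = SevenNetGraphDataset(cutoff=5.0, root='./', files='train.extxyz')
print(len(dataset), dataset.species, dataset.avg_num_neigh)
graph = dataset[0]  # AtomGraphData; KEY.INFO is dropped when drop_info=True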
mace-bench/3rdparty/SevenNet/sevenn/train/loss.py
0 → 100644
View file @
ca86f720
from typing import Any, Callable, Dict, Optional, Tuple

import torch

import sevenn._keys as KEY


class LossDefinition:
    """
    Base class for loss definition.
    Weights are defined outside of the class.
    """

    def __init__(
        self,
        name: str,
        unit: Optional[str] = None,
        criterion: Optional[Callable] = None,
        ref_key: Optional[str] = None,
        pred_key: Optional[str] = None,
        use_weight: bool = False,
        ignore_unlabeled: bool = True,
    ):
        self.name = name
        self.unit = unit
        self.criterion = criterion
        self.ref_key = ref_key
        self.pred_key = pred_key
        self.use_weight = use_weight
        self.ignore_unlabeled = ignore_unlabeled

    def __repr__(self):
        return self.name

    def assign_criteria(self, criterion: Callable):
        if self.criterion is not None:
            raise ValueError('Loss uses its own criterion.')
        self.criterion = criterion

    def _preprocess(
        self, batch_data: Dict[str, Any], model: Optional[Callable] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        if self.pred_key is None or self.ref_key is None:
            raise NotImplementedError('LossDefinition is not implemented.')
        pred = torch.reshape(batch_data[self.pred_key], (-1,))
        ref = torch.reshape(batch_data[self.ref_key], (-1,))
        return pred, ref, None

    def _ignore_unlabeled(self, pred, ref, data_weights=None):
        unlabeled = torch.isnan(ref)
        pred = pred[~unlabeled]
        ref = ref[~unlabeled]
        if data_weights is not None:
            data_weights = data_weights[~unlabeled]
        return pred, ref, data_weights

    def get_loss(self, batch_data: Dict[str, Any], model: Optional[Callable] = None):
        """
        Function that returns a scalar.
        """
        if self.criterion is None:
            raise NotImplementedError('LossDefinition has no criterion.')
        pred, ref, w_tensor = self._preprocess(batch_data, model)
        if self.ignore_unlabeled:
            pred, ref, w_tensor = self._ignore_unlabeled(pred, ref, w_tensor)
            if len(pred) == 0:
                assert self.ref_key is not None
                return torch.zeros(1, device=batch_data[self.ref_key].device)
        loss = self.criterion(pred, ref)
        if self.use_weight:
            loss = torch.mean(loss * w_tensor)
        return loss


class PerAtomEnergyLoss(LossDefinition):
    """
    Loss for per atom energy.
    """

    def __init__(
        self,
        name: str = 'Energy',
        unit: str = 'eV/atom',
        criterion: Optional[Callable] = None,
        ref_key: str = KEY.ENERGY,
        pred_key: str = KEY.PRED_TOTAL_ENERGY,
        **kwargs,
    ):
        super().__init__(
            name=name,
            unit=unit,
            criterion=criterion,
            ref_key=ref_key,
            pred_key=pred_key,
            **kwargs,
        )

    def _preprocess(
        self, batch_data: Dict[str, Any], model: Optional[Callable] = None
    ):
        num_atoms = batch_data[KEY.NUM_ATOMS]
        assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
        pred = batch_data[self.pred_key] / num_atoms
        ref = batch_data[self.ref_key] / num_atoms
        w_tensor = None
        if self.use_weight:
            loss_type = self.name.lower()
            weight = batch_data[KEY.DATA_WEIGHT][loss_type]
            w_tensor = torch.repeat_interleave(weight, 1)
        return pred, ref, w_tensor


class ForceLoss(LossDefinition):
    """
    Loss for force.
    """

    def __init__(
        self,
        name: str = 'Force',
        unit: str = 'eV/A',
        criterion: Optional[Callable] = None,
        ref_key: str = KEY.FORCE,
        pred_key: str = KEY.PRED_FORCE,
        **kwargs,
    ):
        super().__init__(
            name=name,
            unit=unit,
            criterion=criterion,
            ref_key=ref_key,
            pred_key=pred_key,
            **kwargs,
        )

    def _preprocess(
        self, batch_data: Dict[str, Any], model: Optional[Callable] = None
    ):
        assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
        pred = torch.reshape(batch_data[self.pred_key], (-1,))
        ref = torch.reshape(batch_data[self.ref_key], (-1,))
        w_tensor = None
        if self.use_weight:
            loss_type = self.name.lower()
            weight = batch_data[KEY.DATA_WEIGHT][loss_type]
            w_tensor = weight[batch_data[KEY.BATCH]]
            w_tensor = torch.repeat_interleave(w_tensor, 3)
        return pred, ref, w_tensor


class StressLoss(LossDefinition):
    """
    Loss for stress; this is kbar.
    """

    def __init__(
        self,
        name: str = 'Stress',
        unit: str = 'kbar',
        criterion: Optional[Callable] = None,
        ref_key: str = KEY.STRESS,
        pred_key: str = KEY.PRED_STRESS,
        **kwargs,
    ):
        super().__init__(
            name=name,
            unit=unit,
            criterion=criterion,
            ref_key=ref_key,
            pred_key=pred_key,
            **kwargs,
        )
        self.TO_KB = 1602.1766208  # eV/A^3 to kbar

    def _preprocess(
        self, batch_data: Dict[str, Any], model: Optional[Callable] = None
    ):
        assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str)
        pred = torch.reshape(batch_data[self.pred_key] * self.TO_KB, (-1,))
        ref = torch.reshape(batch_data[self.ref_key] * self.TO_KB, (-1,))
        w_tensor = None
        if self.use_weight:
            loss_type = self.name.lower()
            weight = batch_data[KEY.DATA_WEIGHT][loss_type]
            w_tensor = torch.repeat_interleave(weight, 6)
        return pred, ref, w_tensor


def get_loss_functions_from_config(config: Dict[str, Any]):
    from sevenn.train.optim import loss_dict

    loss_functions = []  # list of tuples (loss_definition, weight)

    loss = loss_dict[config[KEY.LOSS].lower()]
    loss_param = config.get(KEY.LOSS_PARAM, {})
    use_weight = config.get(KEY.USE_WEIGHT, False)
    if use_weight:
        loss_param['reduction'] = 'none'
    criterion = loss(**loss_param)

    commons = {'use_weight': use_weight}
    loss_functions.append((PerAtomEnergyLoss(**commons), 1.0))
    loss_functions.append((ForceLoss(**commons), config[KEY.FORCE_WEIGHT]))
    if config[KEY.IS_TRAIN_STRESS]:
        loss_functions.append((StressLoss(**commons), config[KEY.STRESS_WEIGHT]))

    for loss_function, _ in loss_functions:
        # assign the shared criterion unless a loss brought its own
        if loss_function.criterion is None:
            loss_function.assign_criteria(criterion)

    return loss_functions
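
A small sketch of the NaN masking implemented by LossDefinition above; the hand-built batch dict is an assumption for illustration (real batches come from the PyG dataloader):

import torch

import sevenn._keys as KEY
from sevenn.train.loss import ForceLoss

loss_fn = ForceLoss(criterion=torch.nn.MSELoss())
batch = {
    KEY.PRED_FORCE: torch.tensor([[0.1, 0.0, 0.0], [1.0, 1.0, 1.0]]),
    KEY.FORCE: torch.tensor([[0.0, 0.0, 0.0], [float('nan')] * 3]),
}
# the NaN-labeled atom is masked out, so only the first row contributes
print(loss_fn.get_loss(batch))  # MSE over 3 components: 0.1**2 / 3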
mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py
0 → 100644
View file @
ca86f720
import bisect
import os
from copy import deepcopy
from typing import Any, Dict, List, Optional

import numpy as np
from torch.utils.data import ConcatDataset, Dataset

import sevenn._keys as KEY
import sevenn.util as util
from sevenn.logger import Logger


def _arrange_paths_by_modality(paths: List[dict]):
    modal_dct = {}
    for path in paths:
        if isinstance(path, dict):
            if KEY.DATA_MODALITY not in path:
                raise ValueError(f'{KEY.DATA_MODALITY} is missing')
            modal = path.pop(KEY.DATA_MODALITY)
        else:
            raise TypeError(f'{path} is not dict or str')
        if modal not in modal_dct:
            modal_dct[modal] = []
        modal_dct[modal].append(path)
    return modal_dct


def combined_variance(
    means: np.ndarray,
    stds: np.ndarray,
    sample_sizes: np.ndarray,
    ddof: int = 0,
) -> float:
    """
    Calculate the combined variance for multiple datasets.
    """
    assert len(means) == len(stds) and len(stds) == len(sample_sizes)
    # Total number of samples
    total_samples = np.sum(sample_sizes)
    # Combined mean
    combined_mean = np.sum(sample_sizes * means) / total_samples
    # Combined variance calculation
    variance_terms = (sample_sizes - ddof) * (stds**2)
    mean_diff_terms = sample_sizes * ((means - combined_mean) ** 2)
    combined_variance = (
        np.sum(variance_terms) + np.sum(mean_diff_terms)
    ) / (total_samples - ddof)
    return combined_variance
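

# Sanity check for the pooled formula above (a sketch, not part of the module):
# with ddof=0 it reproduces np.var of the concatenated samples, since
#   var = (sum_i n_i * s_i**2 + sum_i n_i * (m_i - M)**2) / N
#
#   a, b = np.random.randn(100), np.random.randn(50) + 1.0
#   pooled = combined_variance(
#       np.array([a.mean(), b.mean()]),
#       np.array([a.std(), b.std()]),
#       np.array([100, 50]),
#   )
#   assert np.isclose(pooled, np.concatenate([a, b]).var())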


def combined_std(
    means: List[float], stds: List[float], sample_sizes: List[int]
) -> float:
    """
    Calculate the combined std for multiple datasets.
    """
    assert len(means) == len(stds) and len(stds) == len(sample_sizes)
    means_arr = np.array(means)
    stds_arr = np.array(stds)
    sample_sizes_arr = np.array(sample_sizes)
    cv = combined_variance(means_arr, stds_arr, sample_sizes_arr)
    return np.sqrt(cv)


def combined_mean(means: List[float], sample_sizes: List[int]) -> float:
    """
    Calculate the combined mean for multiple datasets.
    """
    assert len(means) == len(sample_sizes)
    means_arr = np.array(means)
    sample_sizes_arr = np.array(sample_sizes)
    return np.sum(sample_sizes_arr * means_arr) / np.sum(sample_sizes_arr)


def combined_rms(
    means: List[float], stds: List[float], sample_sizes: List[int]
) -> float:
    """
    Calculate the combined RMS for multiple datasets.
    """
    assert len(means) == len(stds) and len(stds) == len(sample_sizes)
    means_arr = np.array(means)
    stds_arr = np.array(stds)
    sample_sizes_arr = np.array(sample_sizes)
    cm = combined_mean(means, sample_sizes)
    cv = combined_variance(means_arr, stds_arr, sample_sizes_arr)
    # Combined RMS calculation
    return np.sqrt(cm**2 + cv)


class SevenNetMultiModalDataset(ConcatDataset):
    def __init__(
        self,
        modal_dataset_dict: Dict[str, Dataset],
    ):
        datasets = []
        modals = []
        for modal, dataset in modal_dataset_dict.items():
            modals.append(modal)
            datasets.append(dataset)
        self.modals = modals
        super().__init__(datasets)

    def __getitem__(self, idx):
        graph = super().__getitem__(idx)
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        modality = self.modals[dataset_idx]
        graph[KEY.DATA_MODALITY] = modality
        return graph

    def _modal_wise_property(self, attribute_name: str):
        dct = {}
        for modal, dataset in zip(self.modals, self.datasets):
            try:
                if hasattr(dataset, attribute_name):
                    dct[modal] = getattr(dataset, attribute_name)
            except AttributeError:
                dct[modal] = None
        return dct

    @property
    def dataset_dict(self):
        arr = {}
        for idx, modality in enumerate(self.modals):
            arr[modality] = self.datasets[idx]
        return arr

    @property
    def species(self):
        dct = self._modal_wise_property('species')
        tot = set()
        for sp in dct.values():
            tot.update(sp)
        dct['total'] = list(tot)
        return dct

    @property
    def natoms(self):
        return self._modal_wise_property('natoms')

    @property
    def per_atom_energy_mean(self):
        dct = self._modal_wise_property('per_atom_energy_mean')
        try:
            means = []
            sample_sizes = []
            for modality, mean in dct.items():
                means.append(mean)
                sample_sizes.append(
                    self.statistics[modality][KEY.PER_ATOM_ENERGY]['count']
                )
            cm = combined_mean(means, sample_sizes)
            dct['total'] = cm
        except KeyError:
            pass
        return dct

    @property
    def elemwise_reference_energies(self):
        # total is not supported (it is expensive and complex, but useless)
        return self._modal_wise_property('elemwise_reference_energies')

    @property
    def force_rms(self):
        dct = self._modal_wise_property('force_rms')
        try:
            means = []
            sample_sizes = []
            stds = []
            for modality in dct:
                means.append(self.statistics[modality][KEY.FORCE]['mean'])
                sample_sizes.append(self.statistics[modality][KEY.FORCE]['count'])
                stds.append(self.statistics[modality][KEY.FORCE]['std'])
            cm = combined_rms(means, stds, sample_sizes)
            dct['total'] = cm
        except KeyError:
            pass
        return dct

    @property
    def per_atom_energy_std(self):
        dct = self._modal_wise_property('per_atom_energy_std')
        try:
            means = []
            sample_sizes = []
            stds = []
            for modality in dct:
                means.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['mean'])
                sample_sizes.append(
                    self.statistics[modality][KEY.PER_ATOM_ENERGY]['count']
                )
                stds.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['std'])
            cm = combined_std(means, stds, sample_sizes)
            dct['total'] = cm
        except KeyError:
            pass
        return dct

    @property
    def avg_num_neigh(self):
        dct = self._modal_wise_property('avg_num_neigh')
        try:
            means = []
            sample_sizes = []
            for modality, mean in dct.items():
                means.append(mean)
                sample_sizes.append(
                    self.statistics[modality]['num_neighbor']['count']
                )
            cm = combined_mean(means, sample_sizes)
            dct['total'] = cm
        except KeyError:
            pass
        return dct

    @property
    def sqrt_avg_num_neigh(self):
        avg_nn = self.avg_num_neigh
        return {k: v**0.5 for k, v in avg_nn.items()}

    @property
    def statistics(self):
        return self._modal_wise_property('statistics')

    @staticmethod
    def as_graph_dataset(
        paths: List[dict],
        **graph_dataset_kwargs,
    ):
        import sevenn.train.graph_dataset as gd

        modal_paths = _arrange_paths_by_modality(paths)
        dataset_dct = {}
        for modality, paths in modal_paths.items():
            kwargs = deepcopy(graph_dataset_kwargs)
            if (dataset := gd.from_single_path(paths, **kwargs)) is None:
                pname = kwargs.pop('processed_name', 'graph').replace('.pt', '')
                dataset = gd.SevenNetGraphDataset(
                    files=paths,
                    processed_name=f'{pname}_{modality}.pt',
                    **kwargs,
                )
            dataset_dct[modality] = dataset
        return SevenNetMultiModalDataset(dataset_dct)


def from_config(
    config: Dict[str, Any],
    working_dir: str = os.getcwd(),
    dataset_keys: Optional[List[str]] = None,
):
    log = Logger()
    if dataset_keys is None:
        dataset_keys = [
            k for k in config if (k.startswith('load_') and k.endswith('_path'))
        ]

    if KEY.LOAD_TRAINSET not in dataset_keys:
        raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config')

    dataset_args = {
        'cutoff': config[KEY.CUTOFF],
        'root': working_dir,
        'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1),
        'use_data_weight': config.get(KEY.USE_WEIGHT, False),
        **config[KEY.DATA_FORMAT_ARGS],
    }

    datasets = {}
    for dk in dataset_keys:
        if not (paths := config[dk]):
            continue
        if isinstance(paths, str):
            paths = [paths]
        name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]])
        dataset_args.update({'processed_name': name})
        datasets[name] = SevenNetMultiModalDataset.as_graph_dataset(
            paths,  # type: ignore
            **dataset_args,
        )

    train_set = datasets['trainset']

    modals_dataset = set()
    chem_species = set()
    # print statistics of each dataset
    for name, dataset in datasets.items():
        for idx, modality in enumerate(dataset.modals):
            log.bar()
            log.writeline(f'{name} - {modality} distribution:')
            log.statistic_write(dataset.statistics[modality])
            log.format_k_v(
                '# structures (graph)', len(dataset.datasets[idx]), write=True
            )
            modals_dataset.update([modality])
        chem_species.update(dataset.species['total'])
    log.bar()

    if (modal_map := config.get(KEY.MODAL_MAP, None)) is None:
        modals = sorted(list(modals_dataset))
        modal_map = {modal: i for i, modal in enumerate(modals)}
        config[KEY.MODAL_MAP] = modal_map

    modals = list(modal_map.keys())
    if not modals_dataset.issubset(modal_map):
        raise ValueError(
            f'Found modalities in datasets: {modals_dataset} are not subset of'
            + f' {modals}. Use sevenn_cp tool to append/assign modality'
        )
    log.writeline(f'Modalities of this model: {modals}')
    config[KEY.NUM_MODALITIES] = len(modal_map)

    # initialize known species from dataset if 'auto'
    # sorted to alphabetical order (which is same as before)
    chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP]
    if all([config[ck] == 'auto' for ck in chem_keys]):  # see parse_input.py
        log.writeline('Known species are obtained from the dataset')
        config.update(util.chemical_species_preprocess(sorted(list(chem_species))))

    # retrieve shift, scale, conv_denominators from user input (keyword)
    init_from_stats_candid = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR]
    init_from_stats = [
        k for k in init_from_stats_candid if isinstance(config[k], str)
    ]
    for k in init_from_stats:
        input = config[k]
        if not hasattr(train_set, input):
            raise NotImplementedError(input)
        modal_stat = getattr(train_set, input)
        try:
            if k == KEY.CONV_DENOMINATOR and 'total' in modal_stat:
                # conv_denominator is not modal-wise
                var = modal_stat['total']
            elif k == KEY.SHIFT and config[KEY.USE_MODAL_WISE_SHIFT]:
                modal_stat.pop('total', None)
                var = modal_stat
            elif k == KEY.SHIFT and not config[KEY.USE_MODAL_WISE_SHIFT]:
                var = modal_stat['total']
            elif k == KEY.SCALE and config[KEY.USE_MODAL_WISE_SCALE]:
                modal_stat.pop('total', None)
                var = modal_stat
            elif k == KEY.SCALE and not config[KEY.USE_MODAL_WISE_SCALE]:
                var = modal_stat['total']
            else:
                raise NotImplementedError(f'Failed to init {k} from statistics')
        except KeyError as e:
            if e.args[0] == 'total':
                raise NotImplementedError(
                    f'{k}: {input} does not support total statistics. '
                    + f'Set use_modal_wise_{k} True or specify numbers manually'
                )
            else:
                raise e
        config.update({k: var})
        log.writeline(f'{k} is obtained from statistics')

    return datasets
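
A sketch of the path layout that as_graph_dataset consumes, inferred from _arrange_paths_by_modality above; the file names and modality labels are placeholders, not a verified recipe:

import sevenn._keys as KEY
from sevenn.train.modal_dataset import SevenNetMultiModalDataset

paths = [
    {'file_list': [{'file': 'pbe_train.extxyz'}], KEY.DATA_MODALITY: 'pbe'},
    {'file_list': [{'file': 'scan_train.extxyz'}], KEY.DATA_MODALITY: 'scan'},
]
ds = SevenNetMultiModalDataset.as_graph_dataset(paths, cutoff=5.0, root='./')
graph = ds[0]  # __getitem__ stamps KEY.DATA_MODALITY onto each graph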
mace-bench/3rdparty/SevenNet/sevenn/train/optim.py
0 → 100644
View file @
ca86f720
import torch.nn as nn
import torch.optim.lr_scheduler as scheduler
from torch.optim import adagrad, adam, adamw, radam, sgd

optim_dict = {
    'sgd': sgd.SGD,
    'adagrad': adagrad.Adagrad,
    'adam': adam.Adam,
    'adamw': adamw.AdamW,
    'radam': radam.RAdam,
}

scheduler_dict = {
    'steplr': scheduler.StepLR,
    'multisteplr': scheduler.MultiStepLR,
    'exponentiallr': scheduler.ExponentialLR,
    'cosineannealinglr': scheduler.CosineAnnealingLR,
    'reducelronplateau': scheduler.ReduceLROnPlateau,
    'linearlr': scheduler.LinearLR,
}

loss_dict = {'mse': nn.MSELoss, 'huber': nn.HuberLoss}
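
These tables are looked up with lower-cased strings from the config (see get_loss_functions_from_config in loss.py). A minimal resolution sketch with placeholder hyperparameters:

import torch

from sevenn.train.optim import loss_dict, optim_dict, scheduler_dict

model = torch.nn.Linear(4, 1)
optimizer = optim_dict['adamw'](model.parameters(), lr=1e-3)
lr_scheduler = scheduler_dict['exponentiallr'](optimizer, gamma=0.99)
criterion = loss_dict['huber'](delta=0.01, reduction='none')  # per-element loss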