    Filters atoms based on the provided filtering type and element subset.

    Parameters:
        atoms (ase.Atoms): The atoms object to filter.
        element_subset (list): The list of elements to consider during filtering.
        filtering_type (FilteringType): The type of filtering to apply.
            Can be one of the following `FilteringType` enum members:
            - `FilteringType.NONE`: no filtering is applied.
            - `FilteringType.COMBINATIONS`: return True if `atoms` contains only
              elements drawn from the subset, False otherwise. Not all of the
              specified elements need to be present.
            - `FilteringType.EXCLUSIVE`: return True if `atoms` contains exactly
              the elements in the subset (no missing and no additional element
              types), False otherwise.
            - `FilteringType.INCLUSIVE`: return True if `atoms` contains all
              elements in the subset, False otherwise. Additional elements are
              allowed.

    Returns:
        bool: True if the atoms pass the filter, False otherwise.
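
    Example (doctest-style sketch; assumes the enclosing function is named
    `filter_atoms` and that `ase` is installed):

    >>> from ase import Atoms
    >>> filter_atoms(Atoms("H2O"), ["H", "O"], FilteringType.EXCLUSIVE)
    True
    >>> filter_atoms(Atoms("H2O"), ["H", "O", "C"], FilteringType.COMBINATIONS)
    True
    >>> filter_atoms(Atoms("H2O"), ["H", "O", "C"], FilteringType.INCLUSIVE)
    False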
"""
    if filtering_type == FilteringType.NONE:
        return True
    if filtering_type == FilteringType.COMBINATIONS:
        atom_symbols = np.unique(atoms.symbols)
        return all(
            x in element_subset for x in atom_symbols
        )  # atoms must *only* contain elements in the subset
    if filtering_type == FilteringType.EXCLUSIVE:
        atom_symbols = set(atoms.symbols)
        return atom_symbols == set(element_subset)
    if filtering_type == FilteringType.INCLUSIVE:
        atom_symbols = np.unique(atoms.symbols)
        return all(
            x in atom_symbols for x in element_subset
        )  # atoms must *at least* contain elements in the subset
    raise ValueError(
        f"Filtering type {filtering_type} not recognised. "
        f"Must be one of {list(FilteringType)}."
    )

        logging.warning(
            "Using multiheads finetuning with a foundation model that is not a "
            "Materials Project model; you need to provide a path to a pretraining "
            "file with --pt_train_file."
        )
        args.multiheads_finetuning = False
    if args.multiheads_finetuning:
        assert (
            args.E0s != "average"
        ), "average atomic energies cannot be used for multiheads finetuning"
        if not args.force_mh_ft_lr:
            logging.info(
                "Multihead finetuning mode: setting learning rate to 0.0001 and "
                "EMA to True. To use a different learning rate, set "
                "--force_mh_ft_lr=True."
            )
            args.lr = 0.0001
            args.ema = True
            args.ema_decay = 0.99999
        # Check that the foundation model has a single head; if not, use the first head.
        if hasattr(model_foundation, "heads"):
            if len(model_foundation.heads) > 1:
                logging.warning(
                    "Multihead finetuning with models with more than one head is "
                    "not supported; using the first head as the foundation head."
                )
                model_foundation = remove_pt_head(
                    model_foundation, args.foundation_head
                )
    else:
        args.multiheads_finetuning = False
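
    # Illustrative finetuning invocation exercising the checks above (a sketch:
    # the `mace_run_train` entry point name and all file paths are assumptions):
    #   mace_run_train \
    #       --multiheads_finetuning=True \
    #       --pt_train_file=/path/to/pretraining.xyz \
    #       --force_mh_ft_lr=True --lr=0.0005 \
    #       --E0s='{1: -13.6}'  # explicit E0s; "average" is rejected by the assert above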
    if args.heads is not None:
        args.heads = ast.literal_eval(args.heads)
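        # Illustrative value for --heads that the literal_eval above would parse
        # into a dict mapping head name -> config; the head names and keys shown
        # here are hypothetical, not taken from this excerpt:
        #   --heads '{"default": {"train_file": "train.xyz"}, "pbe": {"train_file": "pbe.xyz"}}'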
        for _, head_dict in args.heads.items():
            # Priority: global args < head property_key values < head info_keys + arrays_keys
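            # A minimal sketch of that precedence (later sources override earlier
            # ones; the key names below are assumptions, and the original loop
            # body is elided from this excerpt):
            # resolved = {
            #     **vars(args),                        # global defaults (lowest priority)
            #     **head_dict.get("property_keys", {}),
            #     **head_dict.get("info_keys", {}),    # highest priority, together
            #     **head_dict.get("arrays_keys", {}),  # with arrays_keys
            # }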
logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: "+"{"+", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}"forzinhead_config.z_table.zs])+"}")
exceptKeyErrorase:
raiseKeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number")frome
    # Load datasets for each head, supporting multiple files per head
    valid_sets = {head: [] for head in heads}
    train_sets = {head: [] for head in heads}
    for head_config in head_configs:
        train_datasets = []
        logging.info(f"Processing datasets for head '{head_config.head_name}'")