Unverified Commit 7059e9a6 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

Add a function to recast all buffer tensors (#473)

* Add a function to recast all buffer tensors

* flake8
parent ad7cad50
...@@ -95,6 +95,11 @@ class BuiltinNet(torch.nn.Module): ...@@ -95,6 +95,11 @@ class BuiltinNet(torch.nn.Module):
self.neural_networks = neurochem.load_model_ensemble( self.neural_networks = neurochem.load_model_ensemble(
self.species, self.ensemble_prefix, self.ensemble_size) self.species, self.ensemble_prefix, self.ensemble_size)
@torch.jit.export
def _recast_long_buffers(self):
self.species_converter.conv_tensor = self.species_converter.conv_tensor.to(dtype=torch.long)
self.aev_computer.triu_index = self.aev_computer.triu_index.to(dtype=torch.long)
def forward(self, species_coordinates: Tuple[Tensor, Tensor], def forward(self, species_coordinates: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies: pbc: Optional[Tensor] = None) -> SpeciesEnergies:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment