Unverified Commit b314360c authored by Jinze (Richard) Xue's avatar Jinze (Richard) Xue Committed by GitHub
Browse files

[bugfix] dtype fix when remove thrust (#588)

* fix

* rm
parent 00548245
......@@ -754,9 +754,10 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
// buffer to store all the pairwise distance (Rij)
auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
auto d_options = torch::dtype(torch::kFloat32).device(coordinates_t.device());
auto d_options = torch::dtype(torch::kUInt8).device(coordinates_t.device());
float inf = std::numeric_limits<float>::infinity();
Tensor tensor_Rij = torch::full(sizeof(PairDist<float>) / sizeof(float) * total_natom_pairs, inf, d_options);
Tensor tensor_Rij =
torch::full(sizeof(PairDist<float>) / sizeof(float) * total_natom_pairs, inf, d_options.dtype(torch::kFloat32));
PairDist<float>* d_Rij = (PairDist<float>*)tensor_Rij.data_ptr();
// buffer to store all the pairwise distance that is needed for Radial AEV
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment