Commit 00064117 authored by Thorsten Kurth's avatar Thorsten Kurth
Browse files

small device fix

parent 4aaff021
...@@ -110,7 +110,6 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke ...@@ -110,7 +110,6 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
int64_t *col_h = col_idx.data_ptr<int64_t>(); int64_t *col_h = col_idx.data_ptr<int64_t>();
int64_t *roff_h = new int64_t[Ho * K + 1]; int64_t *roff_h = new int64_t[Ho * K + 1];
int64_t nrows; int64_t nrows;
// float *val_h = val.data_ptr<float>();
AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] { AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h, preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h,
......
...@@ -57,10 +57,10 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", ...@@ -57,10 +57,10 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
# compute the tensor P^m_n: # compute the tensor P^m_n:
nmax = max(mmax,lmax) nmax = max(mmax,lmax)
vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64, requires_grad=False) vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64, device=x.device, requires_grad=False)
norm_factor = 1. if norm == "ortho" else math.sqrt(4 * math.pi) norm_factor = 1.0 if norm == "ortho" else math.sqrt(4 * math.pi)
norm_factor = 1. / norm_factor if inverse else norm_factor norm_factor = 1.0 / norm_factor if inverse else norm_factor
# initial values to start the recursion # initial values to start the recursion
vdm[0,0,:] = norm_factor / math.sqrt(4 * math.pi) vdm[0,0,:] = norm_factor / math.sqrt(4 * math.pi)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment