Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
c877cda6
Commit
c877cda6
authored
Jul 21, 2025
by
Thorsten Kurth
Browse files
making disco helpers GPU ready
parent
00064117
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
2 deletions
+18
-2
torch_harmonics/csrc/disco/disco_helpers.cpp
torch_harmonics/csrc/disco/disco_helpers.cpp
+18
-2
No files found.
torch_harmonics/csrc/disco/disco_helpers.cpp
View file @
c877cda6
...
@@ -104,6 +104,16 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
...
@@ -104,6 +104,16 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
CHECK_INPUT_TENSOR
(
col_idx
);
CHECK_INPUT_TENSOR
(
col_idx
);
CHECK_INPUT_TENSOR
(
val
);
CHECK_INPUT_TENSOR
(
val
);
// get the input device and make sure all tensors are on the same device
auto
device
=
ker_idx
.
device
();
TORCH_INTERNAL_ASSERT
(
device
.
type
()
==
row_idx
.
device
().
type
()
&&
(
device
.
type
()
==
col_idx
.
device
().
type
())
&&
(
device
.
type
()
==
val
.
device
().
type
()));
// move to cpu
ker_idx
=
ker_idx
.
to
(
torch
::
kCPU
);
row_idx
=
row_idx
.
to
(
torch
::
kCPU
);
col_idx
=
col_idx
.
to
(
torch
::
kCPU
);
val
=
val
.
to
(
torch
::
kCPU
);
int64_t
nnz
=
val
.
size
(
0
);
int64_t
nnz
=
val
.
size
(
0
);
int64_t
*
ker_h
=
ker_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
ker_h
=
ker_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
row_h
=
row_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
row_h
=
row_idx
.
data_ptr
<
int64_t
>
();
...
@@ -117,13 +127,19 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
...
@@ -117,13 +127,19 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
}));
}));
// create output tensor
// create output tensor
auto
options
=
torch
::
TensorOptions
().
dtype
(
row_idx
.
dtype
());
auto
roff_idx
=
torch
::
empty
({
nrows
+
1
},
row_idx
.
options
());
auto
roff_idx
=
torch
::
empty
({
nrows
+
1
},
options
);
int64_t
*
roff_out_h
=
roff_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
roff_out_h
=
roff_idx
.
data_ptr
<
int64_t
>
();
for
(
int64_t
i
=
0
;
i
<
(
nrows
+
1
);
i
++
)
{
roff_out_h
[
i
]
=
roff_h
[
i
];
}
for
(
int64_t
i
=
0
;
i
<
(
nrows
+
1
);
i
++
)
{
roff_out_h
[
i
]
=
roff_h
[
i
];
}
delete
[]
roff_h
;
delete
[]
roff_h
;
// move to original device
ker_idx
=
ker_idx
.
to
(
device
);
row_idx
=
row_idx
.
to
(
device
);
col_idx
=
col_idx
.
to
(
device
);
val
=
val
.
to
(
device
);
roff_idx
=
roff_idx
.
to
(
device
);
return
roff_idx
;
return
roff_idx
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment