Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
c46b6925
Commit
c46b6925
authored
Jun 18, 2025
by
Max Rietmann
Browse files
Applied new formatting
parent
1ea5c4ca
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
769 additions
and
757 deletions
+769
-757
.clang-format
.clang-format
+3
-3
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
+246
-245
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
+155
-155
torch_harmonics/csrc/attention/attention_interface.cu
torch_harmonics/csrc/attention/attention_interface.cu
+2
-2
torch_harmonics/csrc/disco/disco_cuda.cuh
torch_harmonics/csrc/disco/disco_cuda.cuh
+3
-3
torch_harmonics/csrc/disco/disco_cuda_bwd.cu
torch_harmonics/csrc/disco/disco_cuda_bwd.cu
+184
-179
torch_harmonics/csrc/disco/disco_cuda_fwd.cu
torch_harmonics/csrc/disco/disco_cuda_fwd.cu
+176
-170
No files found.
.clang-format
View file @
c46b6925
---
BasedOnStyle: Webkit
IndentWidth:
2
IndentWidth:
4
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignTrailingComments: true
...
...
@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false
BreakConstructorInitializers: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth:
2
ContinuationIndentWidth:
2
ConstructorInitializerIndentWidth:
4
ContinuationIndentWidth:
4
Cpp11BracedListStyle: true
FixNamespaceComments: true
NamespaceIndentation: All
...
...
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
View file @
c46b6925
...
...
@@ -51,7 +51,7 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) \
...
...
@@ -70,8 +70,9 @@
class
ScopeTimer
{
public:
explicit
ScopeTimer
(
const
std
::
string
&
label
=
""
)
:
label_
(
label
),
start_
(
std
::
chrono
::
high_resolution_clock
::
now
())
public:
explicit
ScopeTimer
(
const
std
::
string
&
label
=
""
)
:
label_
(
label
),
start_
(
std
::
chrono
::
high_resolution_clock
::
now
())
{
}
...
...
@@ -82,7 +83,7 @@ public:
std
::
cout
<<
label_
<<
"Elapsed time: "
<<
elapsed
.
count
()
<<
" ms"
<<
std
::
endl
;
}
private:
private:
std
::
string
label_
;
std
::
chrono
::
high_resolution_clock
::
time_point
start_
;
};
...
...
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
View file @
c46b6925
...
...
@@ -45,7 +45,7 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define NNZ_TRESH (32)
...
...
torch_harmonics/csrc/attention/attention_interface.cu
View file @
c46b6925
torch_harmonics/csrc/disco/disco_cuda.cuh
View file @
c46b6925
...
...
@@ -40,7 +40,7 @@
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define MIN_THREADS (64)
#define ELXTH_MAX (32)
...
...
torch_harmonics/csrc/disco/disco_cuda_bwd.cu
View file @
c46b6925
...
...
@@ -140,7 +140,7 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
template
<
int
BDIM_X
,
int
ELXTH
,
int
PSCALE
,
typename
REAL_T
>
__global__
__launch_bounds__
(
BDIM_X
)
void
disco_bwd_blk_k
(
const
int
Hi
,
const
int
Wi
,
const
int
K
,
const
int
Ho
,
const
int
Wo
,
__launch_bounds__
(
BDIM_X
)
void
disco_bwd_blk_k
(
const
int
Hi
,
const
int
Wi
,
const
int
K
,
const
int
Ho
,
const
int
Wo
,
const
int
pscale
,
const
int64_t
*
__restrict__
roff
,
const
int64_t
*
__restrict__
kers
,
const
int64_t
*
__restrict__
rows
,
const
int64_t
*
__restrict__
cols
,
const
REAL_T
*
__restrict__
vals
,
...
...
@@ -173,24 +173,24 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
switch
(
pscale
)
{
case
1
:
disco_bwd_blk_k
<
NTH
,
ELXTH
,
1
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
disco_bwd_blk_k
<
NTH
,
ELXTH
,
1
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
break
;
case
2
:
disco_bwd_blk_k
<
NTH
,
ELXTH
,
2
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
disco_bwd_blk_k
<
NTH
,
ELXTH
,
2
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
break
;
case
3
:
disco_bwd_blk_k
<
NTH
,
ELXTH
,
3
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
disco_bwd_blk_k
<
NTH
,
ELXTH
,
3
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
break
;
default:
disco_bwd_blk_k
<
NTH
,
ELXTH
,
0
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
disco_bwd_blk_k
<
NTH
,
ELXTH
,
0
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
}
}
else
{
launch_kernel
<
NTH
,
ELXTH
+
1
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
,
stream
);
launch_kernel
<
NTH
,
ELXTH
+
1
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
,
stream
);
}
}
return
;
...
...
@@ -231,36 +231,41 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_backward_cuda"
,
([
&
]
{
launch_kernel
<
64
,
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
if
(
Wo
<=
128
*
ELXTH_MAX
)
{
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_backward_cuda"
,
([
&
]
{
launch_kernel
<
128
,
(
ELXTH_MAX
/
2
)
+
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
if
(
Wo
<=
256
*
ELXTH_MAX
)
{
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_backward_cuda"
,
([
&
]
{
launch_kernel
<
256
,
(
ELXTH_MAX
/
2
)
+
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
if
(
Wo
<=
512
*
ELXTH_MAX
)
{
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_backward_cuda"
,
([
&
]
{
launch_kernel
<
512
,
(
ELXTH_MAX
/
2
)
+
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
if
(
Wo
<=
1024
*
ELXTH_MAX
)
{
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_backward_cuda"
,
([
&
]
{
launch_kernel
<
1024
,
(
ELXTH_MAX
/
2
)
+
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
{
fprintf
(
stderr
,
"%s:%d: error, unsupported Wo value (%ld), max supported is %d
\n
"
,
__FILE__
,
__LINE__
,
Wo
,
...
...
torch_harmonics/csrc/disco/disco_cuda_fwd.cu
View file @
c46b6925
...
...
@@ -55,7 +55,8 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
REAL_T
__reg
[
ELXTH
]
=
{
0
};
// align to larger supported fp type
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
__sh_ptr
[];
// REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
__sh_ptr
[];
// REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
REAL_T
*
__sh
=
reinterpret_cast
<
REAL_T
*>
(
__sh_ptr
);
int
col_prev
=
cols
[
soff
];
...
...
@@ -145,7 +146,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
template
<
int
BDIM_X
,
int
ELXTH
,
typename
REAL_T
>
__global__
__launch_bounds__
(
BDIM_X
)
void
disco_fwd_blk_k
(
const
int
Hi
,
const
int
Wi
,
const
int
K
,
const
int
Ho
,
const
int
Wo
,
__launch_bounds__
(
BDIM_X
)
void
disco_fwd_blk_k
(
const
int
Hi
,
const
int
Wi
,
const
int
K
,
const
int
Ho
,
const
int
Wo
,
const
int
pscale
,
const
int64_t
*
__restrict__
roff
,
const
int64_t
*
__restrict__
kers
,
const
int64_t
*
__restrict__
rows
,
const
int64_t
*
__restrict__
cols
,
const
REAL_T
*
__restrict__
vals
,
...
...
@@ -172,11 +173,11 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
const
int
pscale
=
Wi
/
Wo
;
size_t
shmem
=
sizeof
(
*
out_d
)
*
(
Wi
*
2
+
pscale
*
(
NTH
*
ELXTH
-
Wo
));
disco_fwd_blk_k
<
NTH
,
ELXTH
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
disco_fwd_blk_k
<
NTH
,
ELXTH
>
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
}
else
{
launch_kernel
<
NTH
,
ELXTH
+
1
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
,
stream
);
launch_kernel
<
NTH
,
ELXTH
+
1
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
,
stream
);
}
}
return
;
...
...
@@ -218,36 +219,41 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::T
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_forward_cuda"
,
([
&
]
{
launch_kernel
<
64
,
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
if
(
Wo
<=
128
*
ELXTH_MAX
)
{
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_forward_cuda"
,
([
&
]
{
launch_kernel
<
128
,
(
ELXTH_MAX
/
2
)
+
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
if
(
Wo
<=
256
*
ELXTH_MAX
)
{
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_forward_cuda"
,
([
&
]
{
launch_kernel
<
256
,
(
ELXTH_MAX
/
2
)
+
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
if
(
Wo
<=
512
*
ELXTH_MAX
)
{
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_forward_cuda"
,
([
&
]
{
launch_kernel
<
512
,
(
ELXTH_MAX
/
2
)
+
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
if
(
Wo
<=
1024
*
ELXTH_MAX
)
{
AT_DISPATCH_FLOATING_TYPES
(
inp
.
scalar_type
(),
"disco_forward_cuda"
,
([
&
]
{
launch_kernel
<
1024
,
(
ELXTH_MAX
/
2
)
+
1
,
scalar_t
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_idx
.
data_ptr
<
int64_t
>
(),
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
ker_idx
.
data_ptr
<
int64_t
>
(),
row_idx
.
data_ptr
<
int64_t
>
(),
col_idx
.
data_ptr
<
int64_t
>
(),
val
.
data_ptr
<
scalar_t
>
(),
inp
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
stream
);
}));
}
else
{
fprintf
(
stderr
,
"%s:%d: error, unsupported Wo value (%ld), max supported is %d
\n
"
,
__FILE__
,
__LINE__
,
Wo
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment