OpenDAS / torch-harmonics / Commits

Commit 1ea5c4ca, authored Jun 16, 2025 by Max Rietmann

Merged formatting

Parents: 68e7d0fa, 373f9b0b

Showing 11 changed files with 403 additions and 643 deletions (+403, -643)
.clang-format-ignore                                     +1    -1
torch_harmonics/csrc/attention/attention.cuh             +7   -12
torch_harmonics/csrc/attention/attention_bwd_cuda.cu     +1    -1
torch_harmonics/csrc/attention/attention_fwd_cuda.cu    +73   -88
torch_harmonics/csrc/attention/attention_interface.cu    +2    -2
torch_harmonics/csrc/disco/disco.h                       +1    -1
torch_harmonics/csrc/disco/disco_cuda.cuh               +10   -23
torch_harmonics/csrc/disco/disco_cuda_bwd.cu           +145  -242
torch_harmonics/csrc/disco/disco_cuda_fwd.cu           +121  -209
torch_harmonics/csrc/disco/disco_helpers.cpp            +39   -60
torch_harmonics/csrc/disco/disco_interface.cu            +3    -4
.clang-format-ignore
torch_harmonics/csrc/attention/attention.cuh

@@ -36,16 +36,11 @@
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")

torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                    at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
                                    int nlon_out);

std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out);
torch_harmonics/csrc/attention/attention_bwd_cuda.cu

@@ -51,7 +51,7 @@
#define THREADS (64)
#endif

#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call) \
...
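As a quick illustration (not part of this diff): DIV_UP is ceiling division, the standard way to size a grid so every element is covered by at least one thread. A minimal sketch, with a hypothetical kernel name:

    #define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
    // covering n = 1000 elements with 64-thread blocks:
    // DIV_UP(1000, 64) == 16 blocks, the last one partially idle
    // my_kernel<<<DIV_UP(n, 64), 64>>>(...);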
torch_harmonics/csrc/attention/attention_fwd_cuda.cu

@@ -45,36 +45,39 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define NNZ_TRESH (32)

#define CHECK_CUDA(call) \
    { \
        cudaError_t err = call; \
        if (cudaSuccess != err) { \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
            exit(EXIT_FAILURE); \
        } \
    }

#define CHECK_ERROR(errorMessage) \
    { \
        cudaError_t err = cudaGetLastError(); \
        if (cudaSuccess != err) { \
            fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__, \
                    cudaGetErrorString(err)); \
            exit(EXIT_FAILURE); \
        } \
    }

static __device__ float __warp_sum(float val) {
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val) {
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
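The body of __warp_sum_cub is elided here. For reference, the canonical CUB warp-reduce pattern looks like the sketch below; this is an assumption about the elided code, not a copy of it, and WARPS_PER_BLOCK is a hypothetical constant matching the launch configuration:

    #include <cub/cub.cuh>

    template<int WARPS_PER_BLOCK>
    __device__ float warp_sum_cub_sketch(float val) {
        using WarpReduce = cub::WarpReduce<float>;
        // one TempStorage slot per warp in the block
        __shared__ typename WarpReduce::TempStorage temp_storage[WARPS_PER_BLOCK];
        float sum = WarpReduce(temp_storage[threadIdx.y]).Sum(val);
        // Sum() leaves the result in lane 0; broadcast it to the whole warp
        return __shfl_sync(0xFFFFFFFF, sum, 0);
    }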
@@ -85,40 +88,33 @@ static __device__ float __warp_sum_cub(float val) {
    return sum;
}

// one warp per (ho,wo)
template<int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *shy = sh + threadIdx.y * num_channels;

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;

    if (wid >= uint64_t(nlat_out) * nlon_in) { return; }

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
        // useless read, y is always zeroed before kernel is called
        shy[chan] = y[batchId][chan][ho][wo];
@@ -130,23 +126,22 @@ __launch_bounds__(BDIM_X)
    float qdotk_max = -FLT_MAX;

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
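psi_row_offset and psi_col_idx form a CSR (compressed sparse row) index over output rows: the nonzero columns of row ho live at psi_col_idx[psi_row_offset[ho] .. psi_row_offset[ho+1]-1]. To make the indexing concrete, a minimal CSR sketch over a hypothetical 2x3 matrix (illustration only, not data from this kernel):

    //   A = [[5, 0, 0],
    //        [0, 7, 9]]
    int64_t row_off[] = {0, 1, 3}; // row r owns nonzeros row_off[r] .. row_off[r+1]-1
    int64_t col_idx[] = {0, 1, 2};
    float vals[] = {5.f, 7.f, 9.f};
    float x[] = {1.f, 2.f, 3.f};
    float y[] = {0.f, 0.f};

    for (int r = 0; r < 2; r++)                                  // y = A * x
        for (int64_t nz = row_off[r]; nz < row_off[r + 1]; nz++)
            y[r] += vals[nz] * x[col_idx[nz]];                   // y == {5, 41}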
@@ -158,31 +153,23 @@ __launch_bounds__(BDIM_X)
        alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        exp_save = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha + alpha_sum * exp_save;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
        }
        qdotk_max = qdotk_max_tmp;
    }

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
    }

    return;
}
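This is the online (streaming) softmax update: a running maximum qdotk_max and running denominator alpha_sum are rescaled by exp(old_max - new_max) whenever the maximum grows, and the value accumulator shy gets the same rescale, so the softmax is computed in one pass without overflow. A minimal host-side sketch of just the recurrence (an illustration, not the kernel itself):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // one-pass softmax denominator: sum_i exp(z_i - max_j z_j)
    float streaming_softmax_denom(const std::vector<float> &z) {
        float m = -INFINITY; // running maximum of the scores seen so far
        float s = 0.0f;      // running sum, always expressed relative to m
        for (float zi : z) {
            float m_new = std::max(m, zi);
            s = s * std::exp(m - m_new) + std::exp(zi - m_new); // rescale old terms, add the new one
            m = m_new;
        }
        return s;
    }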
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                    at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
                                    int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
@@ -206,7 +193,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    // transpose inputs so that channels are in the last dimension, allowing for
    // coalesced memory access
    nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
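The usual idiom for this "channels-last in memory, channels-first in shape" layout is a double permute. A sketch of the standard PyTorch C++ pattern, assumed rather than copied from the elided lines:

    // make channels innermost in memory, then restore the logical NCHW view;
    // accessor indexing is unchanged, but threads striding over channels now
    // touch contiguous addresses
    auto kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});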
@@ -232,10 +219,10 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    nvtxRangePop();

    torch::Tensor y = torch::empty_like(qy);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * block.y;

    cudaEvent_t start, stop;
    float milliseconds = 0;
@@ -243,9 +230,8 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
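The start/stop events around the launch follow the standard CUDA event-timing pattern. For completeness, a minimal sketch of the full sequence; the closing half is elided in this hunk, so the last three calls are an assumption about what follows:

    CHECK_CUDA(cudaEventRecord(start, stream));
    // kernel<<<grid, block, shmem, stream>>>(...);
    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));                       // wait for the kernel to finish
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop)); // elapsed GPU time in ms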
@@ -267,4 +253,3 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    return y;
}
torch_harmonics/csrc/attention/attention_interface.cu

@@ -31,8 +31,8 @@
#include "attention.cuh"

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
    m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
torch_harmonics/csrc/disco/disco.h
torch_harmonics/csrc/disco/disco_cuda.cuh

@@ -36,32 +36,19 @@
#include <c10/cuda/CUDAStream.h>

#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_INPUT_TENSOR(x) \
    CHECK_CUDA_TENSOR(x);          \
    CHECK_CONTIGUOUS_TENSOR(x)

#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))

#define MIN_THREADS (64)
#define ELXTH_MAX (32)

// forward kernel
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);

// backward kernel
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);
\ No newline at end of file
torch_harmonics/csrc/disco/disco_cuda_bwd.cu

@@ -31,23 +31,12 @@
#include "disco.h"
#include "disco_cuda.cuh"

template<int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                            const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                            const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp,
                            REAL_T *__restrict__ out)
{
    const int tid = threadIdx.x;
@@ -55,36 +44,32 @@ __device__ void disco_bwd_d(const int Hi,
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
    out += bidy * Ho * Wo;

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
    REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);

    // copy current inp row in regs
    REAL_T __reg[ELXTH];
#pragma unroll
    for (int i = 0; i < ELXTH; i++) {
        __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0);
    }

    // reset shared row up to Wo+2, remaining
    // ppscale*(BDIM_X*ELXTH - Wo) locations
    // will be written to but never copied to
    // global mem
    for (int i = 0; i < pscale; i++) {
#pragma unroll
        for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) {
            __sh[i][j + tid] = 0;
        }
    }
    __syncthreads();
@@ -94,7 +79,7 @@ __device__ void disco_bwd_d(const int Hi,
    int w_prev = col_prev % Wo;

    // loops along the colums of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {
        const int col = cols[nz];
        const REAL_T val = vals[nz];
@@ -104,16 +89,16 @@ __device__ void disco_bwd_d(const int Hi,
        // to shmem;
        // we read a col that points to a new output
        // row if (col / Wo) > (col_prev / Wo)
        if (col >= col_prev - w_prev + Wo) {
            __syncthreads();
            for (int i = 0; i < pscale; i++) {
                for (int j = tid; j < Wi; j += BDIM_X) {
                    const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
                    atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
                    __sh[i][j] = 0;
                    __sh[i][Wi + j] = 0;
                }
            }
@@ -124,15 +109,15 @@ __device__ void disco_bwd_d(const int Hi,
            w_prev = col % Wo;
        }

        const int w = w_prev + (col - col_prev);
        const int w_mod_ps = w % pscale;
        const int w_div_ps = w / pscale;

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {
            const int pp = i * BDIM_X + tid;
            __sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
        }

        // to avoid race conditions on __sh[]
@@ -142,127 +127,78 @@ __device__ void disco_bwd_d(const int Hi,
    __syncthreads();

    // write last row
    for (int i = 0; i < pscale; i++) {
        for (int j = tid; j < Wi; j += BDIM_X) {
            const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
            atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
        }
    }
    return;
}

template<int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__ __launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho,
                                                          const int Wo, const int pscale,
                                                          const int64_t *__restrict__ roff,
                                                          const int64_t *__restrict__ kers,
                                                          const int64_t *__restrict__ rows,
                                                          const int64_t *__restrict__ cols,
                                                          const REAL_T *__restrict__ vals,
                                                          const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    if constexpr (PSCALE != 0) {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
    } else {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    }
    return;
}

template<int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t *roff_d,
                          int64_t *ker_d, int64_t *row_d, int64_t *col_d, REAL_T *val_d, REAL_T *inp_d,
                          REAL_T *out_d, cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wi) {
            dim3 grid(nrows, BC);

            const int pscale = Wo / Wi;
            size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale);

            switch (pscale) {
                case 1:
                    disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d,
                                                                                 ker_d, row_d, col_d, val_d, inp_d,
                                                                                 out_d);
                    break;
                case 2:
                    disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d,
                                                                                 ker_d, row_d, col_d, val_d, inp_d,
                                                                                 out_d);
                    break;
                case 3:
                    disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d,
                                                                                 ker_d, row_d, col_d, val_d, inp_d,
                                                                                 out_d);
                    break;
                default:
                    disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d,
                                                                                 ker_d, row_d, col_d, val_d, inp_d,
                                                                                 out_d);
            }
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
                                          out_d, stream);
        }
    }
    return;
}

torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
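launch_kernel searches for the smallest per-thread element count at compile time: if NTH*ELXTH threads-times-elements cannot cover the row width, it recurses with ELXTH+1, so the kernel's register array REAL_T __reg[ELXTH] always has a compile-time size. A stripped-down sketch of the pattern, with hypothetical names:

    template<int NTH, int ELXTH>
    void pick_elxth(int width) {
        if constexpr (ELXTH <= ELXTH_MAX) {         // stop template instantiation at the cap
            if (NTH * ELXTH >= width) {
                // launch the kernel instantiated with this ELXTH
            } else {
                pick_elxth<NTH, ELXTH + 1>(width);  // otherwise try one more element per thread
            }
        }
    }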
@@ -289,85 +225,52 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp,
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                               roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                               row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                               val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                               out.data_ptr<scalar_t>(), stream);
        }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}

// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
// }
torch_harmonics/csrc/disco/disco_cuda_fwd.cu

@@ -31,23 +31,12 @@
#include "disco.h"
#include "disco_cuda.cuh"

template<int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                            const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                            const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp,
                            REAL_T *__restrict__ out)
{
    const int tid = threadIdx.x;
@@ -55,13 +44,13 @@ __device__ void disco_fwd_d(const int Hi,
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * Hi * Wi;
    out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;

    REAL_T __reg[ELXTH] = {0};
@@ -75,16 +64,16 @@ __device__ void disco_fwd_d(const int Hi,
    int w_prev = col_prev % Wi;

    // copy current inp row in shmem
    for (int i = tid; i < Wi; i += BDIM_X) {
        const REAL_T v = inp[h_prev * Wi + i];
        __sh[i] = v;
        __sh[Wi + i] = v;
    }
    // locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used
    __syncthreads();

    // loops along the colums of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {
        const int col = cols[nz];
        const REAL_T val = vals[nz];
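Storing the input row twice (__sh[i] and __sh[Wi + i]) is the classic trick for periodic, wrap-around access: any window that starts inside the row can then be read with a plain offset instead of a modulo. A minimal sketch, with hypothetical W and row:

    constexpr int W = 8;          // hypothetical row width
    float row[W] = {0};           // hypothetical input row
    float sh[2 * W];
    for (int i = 0; i < W; i++) {
        sh[i] = row[i];
        sh[W + i] = row[i];
    }
    // now sh[w + d] == row[(w + d) % W] for any 0 <= w, d < W:
    // cyclic reads become plain offsets, no modulo in the inner loop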
@@ -94,27 +83,27 @@ __device__ void disco_fwd_d(const int Hi,
        // to shmem;
        // checks whether (h_prev < h) with:
        // (col >= col_prev - (col_prev % Wi) + Wi)
        if (col >= col_prev - w_prev + Wi) {
            col_prev = col;
            h_prev = col / Wi;
            w_prev = col % Wi;

            __syncthreads();
            for (int i = tid; i < Wi; i += BDIM_X) {
                const REAL_T v = inp[h_prev * Wi + i];
                __sh[i] = v;
                __sh[Wi + i] = v;
            }
            __syncthreads();
        }

        const int w = w_prev + (col - col_prev);

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {
            const int pp = i * BDIM_X + tid;
            // original lines:
            //
@@ -136,17 +125,16 @@ __device__ void disco_fwd_d(const int Hi,
            //
            // with NUM_REM = BDIM_X*ELXTH - Wo
            const int wpp = w + pscale * pp;
            __reg[i] += val * __sh[wpp];
        }
    }

#pragma unroll
    for (int i = 0; i < ELXTH; i++) {
        const int pp = i * BDIM_X + tid;
        if (pp >= Wo) break;
        out[pp] = __reg[i];
@@ -155,92 +143,48 @@ __device__ void disco_fwd_d(const int Hi,
    return;
}

template<int BDIM_X, int ELXTH, typename REAL_T>
__global__ __launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho,
                                                          const int Wo, const int pscale,
                                                          const int64_t *__restrict__ roff,
                                                          const int64_t *__restrict__ kers,
                                                          const int64_t *__restrict__ rows,
                                                          const int64_t *__restrict__ cols,
                                                          const REAL_T *__restrict__ vals,
                                                          const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    return;
}

template<int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t *roff_d,
                          int64_t *ker_d, int64_t *row_d, int64_t *col_d, REAL_T *val_d, REAL_T *inp_d,
                          REAL_T *out_d, cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wo) {
            dim3 grid(nrows, BC);

            const int pscale = Wi / Wo;
            size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));

            disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                      row_d, col_d, val_d, inp_d, out_d);
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
                                          out_d, stream);
        }
    }
    return;
}
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
@@ -267,81 +211,49 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp,
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    // pick the correct launch config
    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                               roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                               row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                               val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                               out.data_ptr<scalar_t>(), stream);
        }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
torch_harmonics/csrc/disco/disco_helpers.cpp

@@ -30,31 +30,21 @@
#include "disco.h"

template<typename REAL_T>
void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, int64_t *row_h, int64_t *col_h,
                           int64_t *roff_h, REAL_T *val_h, int64_t &nrows)
{
    int64_t *Koff = new int64_t[K];

    for (int i = 0; i < K; i++) { Koff[i] = 0; }
    for (int64_t i = 0; i < nnz; i++) { Koff[ker_h[i]]++; }

    int64_t prev = Koff[0];
    Koff[0] = 0;
    for (int i = 1; i < K; i++) {
        int64_t save = Koff[i];
        Koff[i] = prev + Koff[i - 1];
        prev = save;
    }
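This is a counting sort keyed on the kernel index: histogram the keys, convert the histogram into exclusive prefix offsets, then scatter each entry to its slot (the scatter appears in the next hunk). A compact sketch of the same three passes over hypothetical key/perm arrays:

    #include <vector>

    // counting sort of n entries by key[i] in [0, K)
    std::vector<int64_t> off(K, 0);
    for (int64_t i = 0; i < n; i++) off[key[i]]++;             // 1) histogram
    for (int64_t k = 0, run = 0; k < K; k++) {                 // 2) exclusive prefix sum
        int64_t c = off[k]; off[k] = run; run += c;
    }
    std::vector<int64_t> perm(n);
    for (int64_t i = 0; i < n; i++) perm[off[key[i]]++] = i;   // 3) stable scatter of positions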
@@ -63,7 +53,7 @@ void preprocess_psi_kernel(int64_t nnz,
    int64_t *col_sort = new int64_t[nnz];
    float *val_sort = new float[nnz];

    for (int64_t i = 0; i < nnz; i++) {
        const int64_t ker = ker_h[i];
        const int64_t off = Koff[ker]++;
@@ -73,31 +63,30 @@ void preprocess_psi_kernel(int64_t nnz,
        col_sort[off] = col_h[i];
        val_sort[off] = val_h[i];
    }

    for (int64_t i = 0; i < nnz; i++) {
        ker_h[i] = ker_sort[i];
        row_h[i] = row_sort[i];
        col_h[i] = col_sort[i];
        val_h[i] = val_sort[i];
    }

    delete[] Koff;
    delete[] ker_sort;
    delete[] row_sort;
    delete[] col_sort;
    delete[] val_sort;

    // compute rows offsets
    nrows = 1;
    roff_h[0] = 0;
    for (int64_t i = 1; i < nnz; i++) {
        if (row_h[i - 1] == row_h[i]) continue;

        roff_h[nrows++] = i;
        if (nrows > Ho * K) {
            fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%ld)\n", __FILE__, __LINE__,
                    int64_t(Ho) * K);
            exit(EXIT_FAILURE);
        }
    }
@@ -106,13 +95,9 @@ void preprocess_psi_kernel(int64_t nnz,
    return;
}

torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val)
{
    CHECK_INPUT_TENSOR(ker_idx);
    CHECK_INPUT_TENSOR(row_idx);

@@ -123,33 +108,27 @@ torch::Tensor preprocess_psi(const int64_t K,
    int64_t *ker_h = ker_idx.data_ptr<int64_t>();
    int64_t *row_h = row_idx.data_ptr<int64_t>();
    int64_t *col_h = col_idx.data_ptr<int64_t>();

    int64_t *roff_h = new int64_t[Ho * K + 1];
    int64_t nrows;
    // float *val_h = val.data_ptr<float>();
    AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
        preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h, val.data_ptr<scalar_t>(), nrows);
    }));

    // create output tensor
    auto options = torch::TensorOptions().dtype(row_idx.dtype());
    auto roff_idx = torch::empty({nrows + 1}, options);
    int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
    for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }
    delete[] roff_h;

    return roff_idx;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
}
torch_harmonics/csrc/disco/disco_interface.cu

@@ -31,9 +31,8 @@
#include "disco.h"
#include "disco_cuda.cuh"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
    m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
}