OpenDAS / torch-harmonics
Commit 373f9b0b, authored Jun 13, 2025 by Thorsten Kurth
Commit message: formatting
Parent: ebc122eb
Showing 10 changed files with 470 additions and 721 deletions. The changes are whitespace and line-wrapping only; each hunk below is reproduced once, as the code reads after formatting.
torch_harmonics/csrc/attention/attention.cuh            +7    -12
torch_harmonics/csrc/attention/attention_bwd_cuda.cu    +69   -80
torch_harmonics/csrc/attention/attention_fwd_cuda.cu    +73   -88
torch_harmonics/csrc/attention/attention_interface.cu   +2    -2
torch_harmonics/csrc/disco/disco.h                      +1    -1
torch_harmonics/csrc/disco/disco_cuda.cuh               +10   -23
torch_harmonics/csrc/disco/disco_cuda_bwd.cu            +145  -242
torch_harmonics/csrc/disco/disco_cuda_fwd.cu            +121  -209
torch_harmonics/csrc/disco/disco_helpers.cpp            +39   -60
torch_harmonics/csrc/disco/disco_interface.cu           +3    -4
torch_harmonics/csrc/attention/attention.cuh

@@ -36,16 +36,11 @@
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")

torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                    at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
                                    int nlon_out);

std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out);
torch_harmonics/csrc/attention/attention_bwd_cuda.cu

@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//

@@ -51,28 +51,32 @@
#define THREADS (64)
#endif

#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call)                                                                                \
    {                                                                                                   \
        cudaError_t err = call;                                                                         \
        if (cudaSuccess != err) {                                                                       \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__,           \
                    cudaGetErrorString(err));                                                           \
            exit(EXIT_FAILURE);                                                                         \
        }                                                                                               \
    }
#endif

#include <iostream>
#include <chrono>
#include <string>
class ScopeTimer
{
  public:
    explicit ScopeTimer(const std::string &label = "") :
        label_(label), start_(std::chrono::high_resolution_clock::now())
    {
    }

    ~ScopeTimer()
    {
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
        std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
    }

@@ -83,20 +87,19 @@ private:
    std::chrono::high_resolution_clock::time_point start_;
};
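ScopeTimer is an RAII helper: it records the start time on construction and prints the elapsed milliseconds when it goes out of scope. A minimal usage sketch (illustrative only; the label string is arbitrary):

    {
        ScopeTimer timer("permute inputs ");
        // ... work being timed ...
    } // destructor prints: "permute inputs Elapsed time: <n> ms"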
static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;

    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);

    // 2. Broadcast sum to all threads
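The broadcast step that the hunk elides after the comment above is conventionally a lane-0 shuffle, since cub::WarpReduce<float>::Sum() leaves the result only in lane 0. A sketch of what such a broadcast typically looks like (an assumption, not necessarily the elided line):

    sum = __shfl_sync(FULL_MASK, sum, 0); // copy lane 0's reduction result to every lane of the warp
    return sum;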
@@ -108,31 +111,27 @@ static __device__ float __warp_sum_cub(float val) {

// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be layed out in the fastest dimension for coalesced
// memory access.
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
    float *sh_alpha_vw = sh_alpha_k + num_channels;
    float *sh_alpha_kvw = sh_alpha_vw + num_channels;
    float *sh_dy = sh_alpha_kvw + num_channels;
    float *sh_qy = sh_dy + num_channels;
    // (optionally, could use more shared memory for other intermediates)

    const uint64_t batchId = blockIdx.y;
@@ -156,7 +155,7 @@ __launch_bounds__(BDIM_X)

    __syncthreads();

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    // First pass: find qdotk_max
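psi_row_offset and psi_col_idx follow the usual CSR convention: psi_row_offset[ho] and psi_row_offset[ho + 1] delimit the nonzero columns of output row ho, and each column index packs an input point as col = hi * nlon_in + wi. A small host-side sketch of that traversal (illustrative only; the visitor function is hypothetical):

    #include <cstdint>

    // Visit every input point (hi, wi) in the sparse neighborhood of output row ho.
    void visit_row(int64_t ho, const int64_t *psi_row_offset, const int64_t *psi_col_idx, int nlon_in)
    {
        for (int64_t off = psi_row_offset[ho]; off < psi_row_offset[ho + 1]; off++) {
            const int64_t col = psi_col_idx[off];
            const int hi = int(col / nlon_in);      // input latitude index
            const int wi = int(col - hi * nlon_in); // input longitude index
            (void)hi;
            (void)wi; // a real visitor would consume (hi, wi) here
        }
    }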
@@ -166,9 +165,7 @@ __launch_bounds__(BDIM_X)

        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        qdotk_max = max(qdotk_max, qdotk);
    }
@@ -201,7 +198,8 @@ __launch_bounds__(BDIM_X)

    // Write dydq
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        dydq[batchId][chan][ho][wo] =
            (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
    }

    // Third pass: accumulate gradients for k and v
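The dydq expression above has the shape of a quotient-rule derivative: if the attention output is a ratio N(q) / D(q) of softmax-weighted sums and (an assumption based on the naming) sh_alpha_kvw, sh_alpha_vw and sh_alpha_k hold the warp-accumulated numerator and denominator terms, then the update matches

    d/dq ( N(q) / D(q) ) = ( N'(q) * D(q) - N(q) * D'(q) ) / D(q)^2

with D(q) playing the role of alpha_sum.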
@@ -227,16 +225,11 @@ __launch_bounds__(BDIM_X)
    }
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
@@ -257,7 +250,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens

    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
    // auto* permute_timer = new ScopeTimer("permute inputs");
    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
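The comment above describes making channels the fastest-varying dimension in memory while keeping the logical [batch, channel, ho, wo] indexing that the packed accessors expect. A minimal ATen sketch of that trick (the helper name is hypothetical; this is not the code elided by the hunk):

    #include <ATen/ATen.h>

    // Channels-last in memory, original [B, C, H, W] indexing preserved: permute to
    // [B, H, W, C], materialize with .contiguous(), then permute the view back so the
    // shape reads [B, C, H, W] again while the data stays channel-fastest.
    at::Tensor channels_last_keep_shape(const at::Tensor &x)
    {
        return x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    }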
@@ -300,8 +293,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens

    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp

    cudaEvent_t start, stop;
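For a concrete sense of the launch shape (the numbers below are hypothetical): with WARP_SIZE = 32 and THREADS = 64 each block holds two warps, one warp per output point (ho, wo). For nlat_out = 361, nlon_out = 720 and a batch of 4 this gives

    block       = dim3(32, 64 / 32)               = (32, 2)
    grid        = dim3(DIV_UP(361 * 720, 2), 4)   = (129960, 4)
    shared_size = sizeof(float) * num_channels * 5 * 2   // five per-warp channel buffers, two warps per block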
@@ -310,20 +303,18 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens

    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
@@ -333,15 +324,15 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens

    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Permute outputs back to memory layout given by input. if input had channels
    // first, leave it in that layout, otherwise permute layout back to [batch,
    // channel, ho, wo]
    if (!k_channel_first) dydk = dydk.contiguous();
    if (!v_channel_first) dydv = dydv.contiguous();
    if (!q_channel_first) dydq = dydq.contiguous();

    // printf("dydk strides:[");
    // for(auto& stride : dydk.strides()) {
@@ -352,6 +343,4 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens

    // delete permute_output_timer;
    // nvtxRangePop();

    return std::make_tuple(dydk, dydv, dydq);
}
torch_harmonics/csrc/attention/attention_fwd_cuda.cu

@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//

@@ -45,39 +45,42 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define NNZ_TRESH (32)

#define CHECK_CUDA(call)                                                                                \
    {                                                                                                   \
        cudaError_t err = call;                                                                         \
        if (cudaSuccess != err) {                                                                       \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__,           \
                    cudaGetErrorString(err));                                                           \
            exit(EXIT_FAILURE);                                                                         \
        }                                                                                               \
    }

#define CHECK_ERROR(errorMessage)                                                                       \
    {                                                                                                   \
        cudaError_t err = cudaGetLastError();                                                           \
        if (cudaSuccess != err) {                                                                       \
            fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__,   \
                    __LINE__, cudaGetErrorString(err));                                                 \
            exit(EXIT_FAILURE);                                                                         \
        }                                                                                               \
    }

static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;

    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);

    // 2. Broadcast sum to all threads

@@ -85,40 +88,33 @@ static __device__ float __warp_sum_cub(float val) {
    return sum;
}
// one warp per (ho,wo)
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *shy = sh + threadIdx.y * num_channels;

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;

    if (wid >= uint64_t(nlat_out) * nlon_in) { return; }

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
        // useless read, y is always zeroed before kernel is called
        shy[chan] = y[batchId][chan][ho][wo];
@@ -130,23 +126,22 @@ __launch_bounds__(BDIM_X)

    float qdotk_max = -FLT_MAX;

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
@@ -158,31 +153,23 @@ __launch_bounds__(BDIM_X)

        alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        exp_save = expf(qdotk_max - qdotk_max_tmp);
        alpha_sum = alpha + alpha_sum * exp_save;

        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
        }
        qdotk_max = qdotk_max_tmp;
    }

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
    }

    return;
}
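The alpha / exp_save / alpha_sum updates above implement the streaming ("online") softmax recurrence: a running maximum m, normalizer s and value accumulator y are rescaled whenever the maximum grows, so a single pass over the neighborhood suffices. Per neighbor j with quadrature weight w_j = quad_weights[hi], schematically:

    m_j = max(m_{j-1}, q . k_j)
    s_j = s_{j-1} * exp(m_{j-1} - m_j) + w_j * exp(q . k_j - m_j)
    y_j = y_{j-1} * exp(m_{j-1} - m_j) + w_j * exp(q . k_j - m_j) * v_j

and the final output is y_n / s_n, which is what the last loop writes as shy[chan] / alpha_sum.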
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                    at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
                                    int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
@@ -206,7 +193,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,

    // transpose inputs so that channels are in the last dimension, allowing for
    // coalesced memory access
    nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
@@ -232,10 +219,10 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,

    nvtxRangePop();

    torch::Tensor y = torch::empty_like(qy);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * block.y;

    cudaEvent_t start, stop;
    float milliseconds = 0;
@@ -243,15 +230,14 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,

    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
@@ -267,4 +253,3 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,

    return y;
}
torch_harmonics/csrc/attention/attention_interface.cu

@@ -31,8 +31,8 @@
#include "attention.cuh"
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
    m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
torch_harmonics/csrc/disco/disco.h

@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
torch_harmonics/csrc/disco/disco_cuda.cuh

@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//

@@ -36,32 +36,19 @@
#include <c10/cuda/CUDAStream.h>

#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_INPUT_TENSOR(x)  \
    CHECK_CUDA_TENSOR(x);           \
    CHECK_CONTIGUOUS_TENSOR(x)

#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))

#define MIN_THREADS (64)
#define ELXTH_MAX (32)

// forward kernel
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);

// backward kernel
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);
\ No newline at end of file
torch_harmonics/csrc/disco/disco_cuda_bwd.cu

@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//

@@ -31,239 +31,175 @@
#include "disco.h"
#include "disco_cuda.cuh"
template <int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                            const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                            const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    const int tid = threadIdx.x;
    const int64_t bidx = blockIdx.x; // gloabl row
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
    out += bidy * Ho * Wo;

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
    REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);

    // copy current inp row in regs
    REAL_T __reg[ELXTH];
#pragma unroll
    for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }

    // reset shared row up to Wo+2, remaining
    // ppscale*(BDIM_X*ELXTH - Wo) locations
    // will be written to but never copied to
    // global mem
    for (int i = 0; i < pscale; i++) {
#pragma unroll
        for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
    }
    __syncthreads();

    int col_prev = cols[soff];
    int h_prev = col_prev / Wo;
    int w_prev = col_prev % Wo;

    // loops along the colums of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {

        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nz with a col value
        // leading to a new row of inp then copy it
        // to shmem;
        // we read a col that points to a new output
        // row if (col / Wo) > (col_prev / Wo)
        if (col >= col_prev - w_prev + Wo) {

            __syncthreads();
            for (int i = 0; i < pscale; i++) {
                for (int j = tid; j < Wi; j += BDIM_X) {
                    const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
                    atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
                    __sh[i][j] = 0;
                    __sh[i][Wi + j] = 0;
                }
            }
            __syncthreads();

            col_prev = col;
            h_prev = col / Wo;
            w_prev = col % Wo;
        }
        const int w = w_prev + (col - col_prev);
        const int w_mod_ps = w % pscale;
        const int w_div_ps = w / pscale;

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {
            const int pp = i * BDIM_X + tid;
            __sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
        }
        // to avoid race conditions on __sh[]
        // among consecutive iterations along nz
        __syncthreads();
    }
    __syncthreads();

    // write last row
    for (int i = 0; i < pscale; i++) {
        for (int j = tid; j < Wi; j += BDIM_X) {
            const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
            atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
        }
    }
    return;
}
template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__ __launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho,
                                                          const int Wo, const int pscale,
                                                          const int64_t *__restrict__ roff,
                                                          const int64_t *__restrict__ kers,
                                                          const int64_t *__restrict__ rows,
                                                          const int64_t *__restrict__ cols,
                                                          const REAL_T *__restrict__ vals,
                                                          const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    if constexpr (PSCALE != 0) {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
    } else {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    }
    return;
}

template <int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t *roff_d,
                          int64_t *ker_d, int64_t *row_d, int64_t *col_d, REAL_T *val_d, REAL_T *inp_d, REAL_T *out_d,
                          cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wi) {
            dim3 grid(nrows, BC);
            const int pscale = Wo / Wi;
            size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale);

            switch (pscale) {
            case 1:
                disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                             row_d, col_d, val_d, inp_d, out_d);
                break;
            case 2:
                disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                             row_d, col_d, val_d, inp_d, out_d);
                break;
            case 3:
                disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                             row_d, col_d, val_d, inp_d, out_d);
                break;
            default:
                disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                             row_d, col_d, val_d, inp_d, out_d);
            }
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
                                          out_d, stream);
        }
    }
    return;
}
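launch_kernel walks ELXTH upward at compile time until a single block of NTH threads, each handling ELXTH elements, covers one input row of width Wi. As a purely hypothetical example with NTH = 64 and Wi = 1440:

    64 * 22 = 1408 <  1440  ->  recurse into launch_kernel<64, 23, ...>
    64 * 23 = 1472 >= 1440  ->  launch disco_bwd_blk_k<64, 23, PSCALE> with this configuration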
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
    CHECK_CUDA_INPUT_TENSOR(roff_idx);
@@ -287,87 +223,54 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp,

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<64, 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
// }
torch_harmonics/csrc/disco/disco_cuda_fwd.cu

@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//

@@ -31,101 +31,90 @@
#include "disco.h"
#include "disco_cuda.cuh"
template <int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                            const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                            const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    const int tid = threadIdx.x;
    const int64_t bidx = blockIdx.x; // gloabl row
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * Hi * Wi;
    out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;

    REAL_T __reg[ELXTH] = {0};

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
    REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);

    int col_prev = cols[soff];
    int h_prev = col_prev / Wi;
    int w_prev = col_prev % Wi;

    // copy current inp row in shmem
    for (int i = tid; i < Wi; i += BDIM_X) {
        const REAL_T v = inp[h_prev * Wi + i];
        __sh[i] = v;
        __sh[Wi + i] = v;
    }
    // locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used
    __syncthreads();

    // loops along the colums of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {

        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nz with a col value
        // leading to a new row of inp then copy it
        // to shmem;
        // checks whether (h_prev < h) with:
        // (col >= col_prev - (col_prev % Wi) + Wi)
        if (col >= col_prev - w_prev + Wi) {

            col_prev = col;
            h_prev = col / Wi;
            w_prev = col % Wi;

            __syncthreads();
            for (int i = tid; i < Wi; i += BDIM_X) {
                const REAL_T v = inp[h_prev * Wi + i];
                __sh[i] = v;
                __sh[Wi + i] = v;
            }
            __syncthreads();
        }
        const int w = w_prev + (col - col_prev);

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {
            const int pp = i * BDIM_X + tid;

            // original lines:
            //
            // if (pp >= Wo) break;
            // const int wpp = (w + pscale*pp) % Wi;
            //
            // value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi
            // so we can allocate twice the amount of shmem,
            // replicate the current inp row and avoid the costly mod
            //
            // also, to avoid the conditional, sh can be extended to
            // cover the maximum location accessed during this loop
            //
            // REAL_T __sh[2*Wi + ppscale*NUM_REM]
...
@@ -135,113 +124,68 @@ __device__ void disco_fwd_d(const int Hi,
...
@@ -135,113 +124,68 @@ __device__ void disco_fwd_d(const int Hi,
// = 2*Wi + ppscale*NUM_REM
// = 2*Wi + ppscale*NUM_REM
//
//
// with NUM_REM = BDIM_X*ELXTH - Wo
// with NUM_REM = BDIM_X*ELXTH - Wo
const
int
wpp
=
w
+
pscale
*
pp
;
const
int
wpp
=
w
+
pscale
*
pp
;
__reg
[
i
]
+=
val
*
__sh
[
wpp
];
__reg
[
i
]
+=
val
*
__sh
[
wpp
];
}
}
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
ELXTH
;
i
++
)
{
for
(
int
i
=
0
;
i
<
ELXTH
;
i
++
)
{
const
int
pp
=
i
*
BDIM_X
+
tid
;
const
int
pp
=
i
*
BDIM_X
+
tid
;
if
(
pp
>=
Wo
)
break
;
if
(
pp
>=
Wo
)
break
;
out
[
pp
]
=
__reg
[
i
];
out
[
pp
]
=
__reg
[
i
];
}
}
return
;
return
;
}
}
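The comment block above justifies dropping the modulo by keeping two back-to-back copies of the current input row in shared memory. A minimal host-side sketch of that index identity (illustrative only, not part of this commit; it assumes Wi is an integer multiple of Wo, consistent with pscale = Wi/Wo in launch_kernel):

#include <cassert>
#include <vector>

// Host-side check of the indexing trick used in disco_fwd_d: since (w + pscale*pp) < 2*Wi,
// a buffer holding two copies of the input row satisfies
// buf[w + pscale*pp] == row[(w + pscale*pp) % Wi], so the kernel can skip the modulo.
int main() {
    const int Wi = 12, Wo = 4;
    const int pscale = Wi / Wo;                  // same definition as in launch_kernel
    std::vector<float> row(Wi), buf(2 * Wi);
    for (int i = 0; i < Wi; i++) { row[i] = float(i); buf[i] = buf[Wi + i] = row[i]; }

    for (int w = 0; w < Wi; w++) {               // any column within the current input row
        for (int pp = 0; pp < Wo; pp++) {        // any output column handled by a thread
            assert(buf[w + pscale * pp] == row[(w + pscale * pp) % Wi]);
        }
    }
    return 0;
}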
template<int BDIM_X, int ELXTH, typename REAL_T>
__global__ __launch_bounds__(BDIM_X)
void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                     const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                     const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
                     const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp,
                     REAL_T *__restrict__ out)
{
    disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    return;
}
template<int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows,
                          int64_t *roff_d, int64_t *ker_d, int64_t *row_d, int64_t *col_d,
                          REAL_T *val_d, REAL_T *inp_d, REAL_T *out_d, cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wo) {
            dim3 grid(nrows, BC);

            const int pscale = Wi / Wo;
            size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));

            disco_fwd_blk_k<NTH, ELXTH>
                <<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d);
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d, out_d,
                                          stream);
        }
    }
    return;
}
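launch_kernel recursively instantiates itself with ELXTH+1 until NTH*ELXTH covers Wo, or falls through silently once ELXTH_MAX is exceeded. A rough constexpr sketch of that selection rule, with an assumed placeholder value for ELXTH_MAX (the real constant lives in disco_cuda.cuh):

// ELXTH_MAX is assumed to be 32 here purely for illustration.
constexpr int kElxthMax = 32;

// Returns the first ELXTH with nth*ELXTH >= wo, or -1 if even kElxthMax is not enough,
// mirroring the silent fall-through in launch_kernel when ELXTH exceeds ELXTH_MAX.
constexpr int select_elxth(int nth, int wo, int elxth = 1) {
    return (elxth > kElxthMax) ? -1
         : (nth * elxth >= wo) ? elxth
         : select_elxth(nth, wo, elxth + 1);
}

static_assert(select_elxth(64, 60) == 1, "one element per thread already covers Wo");
static_assert(select_elxth(64, 130) == 3, "ceil(130/64) elements per thread");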
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
    CHECK_CUDA_INPUT_TENSOR(roff_idx);
...
@@ -265,83 +209,51 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp,
    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    // pick the correct launch config
    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                           roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                           row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                           val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                           out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                               roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                               row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                               val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                               out.data_ptr<scalar_t>(), stream);
        }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
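The branch ladder above amounts to choosing the smallest block size whose maximum per-block coverage NTH*ELXTH_MAX still reaches Wo. A standalone sketch of that choice (the kElxthMax value and the sample widths are assumptions for illustration, not taken from the repository):

#include <cstdio>
#include <initializer_list>

constexpr long kElxthMax = 32;   // placeholder; the actual value is defined in disco_cuda.cuh

// Picks the smallest supported thread count whose maximum coverage reaches Wo,
// returning -1 for the unsupported case (the fprintf/exit branch above).
int pick_block_size(long Wo) {
    for (int nth : {64, 128, 256, 512, 1024}) {
        if (Wo <= nth * kElxthMax) return nth;
    }
    return -1;
}

int main() {
    // With the assumed kElxthMax = 32 this prints: 64 256 1024
    std::printf("%d %d %d\n", pick_block_size(2000), pick_block_size(5000), pick_block_size(30000));
    return 0;
}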
torch_harmonics/csrc/disco/disco_helpers.cpp
View file @ 373f9b0b
...
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
...
@@ -30,40 +30,30 @@
#include "disco.h"
template<typename REAL_T>
void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, int64_t *row_h, int64_t *col_h,
                           int64_t *roff_h, REAL_T *val_h, int64_t &nrows)
{
    int64_t *Koff = new int64_t[K];
    for (int i = 0; i < K; i++) { Koff[i] = 0; }
    for (int64_t i = 0; i < nnz; i++) { Koff[ker_h[i]]++; }

    int64_t prev = Koff[0];
    Koff[0] = 0;
    for (int i = 1; i < K; i++) {
        int64_t save = Koff[i];
        Koff[i] = prev + Koff[i - 1];
        prev = save;
    }

    int64_t *ker_sort = new int64_t[nnz];
    int64_t *row_sort = new int64_t[nnz];
    int64_t *col_sort = new int64_t[nnz];
    float *val_sort = new float[nnz];

    for (int64_t i = 0; i < nnz; i++) {
        const int64_t ker = ker_h[i];
        const int64_t off = Koff[ker]++;
...
@@ -73,31 +63,30 @@ void preprocess_psi_kernel(int64_t nnz,
        col_sort[off] = col_h[i];
        val_sort[off] = val_h[i];
    }

    for (int64_t i = 0; i < nnz; i++) {
        ker_h[i] = ker_sort[i];
        row_h[i] = row_sort[i];
        col_h[i] = col_sort[i];
        val_h[i] = val_sort[i];
    }

    delete[] Koff;
    delete[] ker_sort;
    delete[] row_sort;
    delete[] col_sort;
    delete[] val_sort;

    // compute row offsets
    nrows = 1;
    roff_h[0] = 0;
    for (int64_t i = 1; i < nnz; i++) {
        if (row_h[i - 1] == row_h[i]) continue;
        roff_h[nrows++] = i;
        if (nrows > Ho * K) {
            fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%ld)\n", __FILE__, __LINE__,
                    int64_t(Ho) * K);
            exit(EXIT_FAILURE);
        }
    }
...
@@ -106,50 +95,40 @@ void preprocess_psi_kernel(int64_t nnz,
    return;
}
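preprocess_psi_kernel is essentially a stable counting sort of the COO entries keyed on the kernel index. A toy sketch of the same histogram / exclusive-scan / scatter pattern (sample data invented for illustration; only the kernel indices are scattered here for brevity, whereas the real code also scatters rows, columns, and values):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t K = 3;
    std::vector<int64_t> ker = {2, 0, 1, 0, 2, 1};                         // kernel index per nonzero
    std::vector<int64_t> count(K, 0), off(K, 0);

    for (int64_t k : ker) count[k]++;                                      // histogram (the Koff counting pass)
    for (int64_t i = 1; i < K; i++) off[i] = off[i - 1] + count[i - 1];    // exclusive scan

    std::vector<int64_t> sorted(ker.size());
    for (size_t i = 0; i < ker.size(); i++) sorted[off[ker[i]]++] = ker[i];  // stable scatter

    for (int64_t k : sorted) std::printf("%lld ", (long long)k);           // prints: 0 0 1 1 2 2
    std::printf("\n");
    return 0;
}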
torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val)
{
    CHECK_INPUT_TENSOR(ker_idx);
    CHECK_INPUT_TENSOR(row_idx);
    CHECK_INPUT_TENSOR(col_idx);
    CHECK_INPUT_TENSOR(val);

    int64_t nnz = val.size(0);
    int64_t *ker_h = ker_idx.data_ptr<int64_t>();
    int64_t *row_h = row_idx.data_ptr<int64_t>();
    int64_t *col_h = col_idx.data_ptr<int64_t>();
    int64_t *roff_h = new int64_t[Ho * K + 1];
    int64_t nrows;

    // float *val_h = val.data_ptr<float>();
    AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
        preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h, val.data_ptr<scalar_t>(), nrows);
    }));

    // create output tensor
    auto options = torch::TensorOptions().dtype(row_idx.dtype());
    auto roff_idx = torch::empty({nrows + 1}, options);
    int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
    for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }
    delete[] roff_h;

    return roff_idx;
}
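The roff_idx tensor returned above plays the role of a CSR-style row pointer over the sorted COO entries. A toy sketch of how such offsets are derived from already-sorted row indices (invented data; the closing offset is an assumption about the part of preprocess_psi_kernel elided in this diff):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int64_t> row = {0, 0, 0, 1, 1, 3, 3, 3};    // sorted row index per nonzero
    std::vector<int64_t> roff = {0};                         // first row starts at offset 0
    for (size_t i = 1; i < row.size(); i++)
        if (row[i] != row[i - 1]) roff.push_back(i);         // a new row starts here
    roff.push_back(row.size());                              // closing offset, giving nrows + 1 entries

    for (int64_t o : roff) std::printf("%lld ", (long long)o);  // prints: 0 3 5 8
    std::printf("\n");
    return 0;
}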
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
}
torch_harmonics/csrc/disco/disco_interface.cu
View file @ 373f9b0b
...
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
...
@@ -31,9 +31,8 @@
#include "disco.h"
#include "disco_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
    m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
}