OpenDAS / torch-harmonics

Commit 4805b39c
Authored Jun 13, 2025 by Thorsten Kurth; committed by Boris Bonev, Jun 17, 2025

formatting

Parent: 5eaa7f79
Showing 10 changed files with 470 additions and 721 deletions (+470, -721).
torch_harmonics/csrc/attention/attention.cuh            +7   -12
torch_harmonics/csrc/attention/attention_bwd_cuda.cu    +69  -80
torch_harmonics/csrc/attention/attention_fwd_cuda.cu    +73  -88
torch_harmonics/csrc/attention/attention_interface.cu   +2   -2
torch_harmonics/csrc/disco/disco.h                      +1   -1
torch_harmonics/csrc/disco/disco_cuda.cuh               +10  -23
torch_harmonics/csrc/disco/disco_cuda_bwd.cu            +145 -242
torch_harmonics/csrc/disco/disco_cuda_fwd.cu            +121 -209
torch_harmonics/csrc/disco/disco_helpers.cpp            +39  -60
torch_harmonics/csrc/disco/disco_interface.cu           +3   -4
torch_harmonics/csrc/attention/attention.cuh
@@ -36,16 +36,11 @@
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                    at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
                                    int nlon_out);

std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out);
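Both entry points consume the sparsity pattern of psi in CSR form: psi_row_off holds one offset per output row and psi_col_idx the concatenated column indices. A minimal host-side sketch of that traversal (illustrative only, not part of this commit; names mirror the declarations above):

#include <cstdio>
#include <vector>

int main()
{
    std::vector<long> psi_row_off = {0, 2, 5};    // 2 rows, with 2 and 3 nonzeros
    std::vector<long> psi_col_idx = {1, 3, 0, 2, 4};

    for (size_t r = 0; r + 1 < psi_row_off.size(); r++) {
        printf("row %zu:", r);
        // row r's nonzero columns live in psi_col_idx[psi_row_off[r] .. psi_row_off[r+1])
        for (long nz = psi_row_off[r]; nz < psi_row_off[r + 1]; nz++) { printf(" %ld", psi_col_idx[nz]); }
        printf("\n");
    }
    return 0;
}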
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -51,28 +51,32 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call)                                                                                    \
    {                                                                                                       \
        cudaError_t err = call;                                                                             \
        if (cudaSuccess != err) {                                                                           \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__,               \
                    cudaGetErrorString(err));                                                               \
            exit(EXIT_FAILURE);                                                                             \
        }                                                                                                   \
    }
#endif
#include <iostream>
#include <chrono>
#include <string>
class ScopeTimer
{
  public:
    explicit ScopeTimer(const std::string& label = "")
        : label_(label), start_(std::chrono::high_resolution_clock::now())
    {
    }

    ~ScopeTimer()
    {
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
        std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
@@ -83,20 +87,19 @@ private:
    std::chrono::high_resolution_clock::time_point start_;
};
static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}
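__shfl_xor_sync implements a butterfly reduction: at stride i every lane adds the value held by lane (lane ^ i), and after log2(WARP_SIZE) = 5 rounds all 32 lanes hold the full sum. A host-side sketch of the same pattern (illustrative only, not part of this commit):

#include <cstdio>

int main()
{
    const int WARP = 32;
    float lane[WARP], next[WARP];
    for (int l = 0; l < WARP; l++) lane[l] = float(l); // arbitrary per-lane values

    for (int i = WARP / 2; i; i /= 2) {                // same loop shape as __warp_sum
        for (int l = 0; l < WARP; l++) next[l] = lane[l] + lane[l ^ i]; // partner is lane XOR stride
        for (int l = 0; l < WARP; l++) lane[l] = next[l];
    }
    printf("lane 0 sum = %f (expected %f)\n", lane[0], 31.0f * 32.0f / 2.0f);
    return 0;
}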
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
    // 2. Broadcast sum to all threads
@@ -108,31 +111,27 @@ static __device__ float __warp_sum_cub(float val) {
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template<int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float* sh_alpha_k = sh + threadIdx.y * num_channels * 5;
    float* sh_alpha_vw = sh_alpha_k + num_channels;
    float* sh_alpha_kvw = sh_alpha_vw + num_channels;
    float* sh_dy = sh_alpha_kvw + num_channels;
    float* sh_qy = sh_dy + num_channels;
    // (optionally, could use more shared memory for other intermediates)

    const uint64_t batchId = blockIdx.y;
@@ -156,7 +155,7 @@ __launch_bounds__(BDIM_X)
    __syncthreads();

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;
// First pass: find qdotk_max
@@ -166,9 +165,7 @@ __launch_bounds__(BDIM_X)
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        qdotk_max = max(qdotk_max, qdotk);
    }
@@ -201,7 +198,8 @@ __launch_bounds__(BDIM_X)
// Write dydq
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        dydq[batchId][chan][ho][wo] = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan])
                                      / (alpha_sum * alpha_sum);
    }
// Third pass: accumulate gradients for k and v
@@ -227,16 +225,11 @@ __launch_bounds__(BDIM_X)
    }
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
@@ -257,7 +250,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
    // auto* permute_timer = new ScopeTimer("permute inputs");

    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
@@ -300,8 +293,8 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 arrays per warp

    cudaEvent_t start, stop;
@@ -310,20 +303,18 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
@@ -333,15 +324,15 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Permute outputs back to memory layout given by input. if input had channels
    // first, leave it in that layout, otherwise permute layout back to [batch,
    // channel, ho, wo]
    if (!k_channel_first) dydk = dydk.contiguous();
    if (!v_channel_first) dydv = dydv.contiguous();
    if (!q_channel_first) dydq = dydq.contiguous();
// printf("dydk strides:[");
// for(auto& stride : dydk.strides()) {
@@ -352,6 +343,4 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// delete permute_output_timer;
// nvtxRangePop();
    return std::make_tuple(dydk, dydv, dydq);
}
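The permute-before-launch comment above describes making the channel dimension fastest-varying in memory while keeping the logical [batch, channel, ho, wo] shape, so the kernel's per-channel loads coalesce. A libtorch sketch of that trick (illustrative only, not part of this commit):

#include <torch/torch.h>
#include <cstdio>

int main()
{
    auto x = torch::randn({2, 8, 4, 4}); // [B, C, H, W]
    // channels-last in memory, then permute back so the logical shape is unchanged
    auto xP = x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    printf("same shape: %d, channel stride now 1: %d\n",
           int(xP.sizes() == x.sizes()), int(xP.stride(1) == 1));
    return 0;
}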
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -45,39 +45,42 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define NNZ_TRESH (32)

#define CHECK_CUDA(call)                                                                                    \
    {                                                                                                       \
        cudaError_t err = call;                                                                             \
        if (cudaSuccess != err) {                                                                           \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__,               \
                    cudaGetErrorString(err));                                                               \
            exit(EXIT_FAILURE);                                                                             \
        }                                                                                                   \
    }

#define CHECK_ERROR(errorMessage)                                                                           \
    {                                                                                                       \
        cudaError_t err = cudaGetLastError();                                                               \
        if (cudaSuccess != err) {                                                                           \
            fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__,       \
                    __LINE__, cudaGetErrorString(err));                                                     \
            exit(EXIT_FAILURE);                                                                             \
        }                                                                                                   \
    }

static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}
// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
    // 2. Broadcast sum to all threads
@@ -85,40 +88,33 @@ static __device__ float __warp_sum_cub(float val) {
    return sum;
}
// one warp per (ho,wo)
template<int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float* shy = sh + threadIdx.y * num_channels;

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
    if (wid >= uint64_t(nlat_out) * nlon_in) { return; }

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
// useless read, y is always zeroed before kernel is called
shy[chan] = y[batchId][chan][ho][wo];
@@ -130,23 +126,22 @@ __launch_bounds__(BDIM_X)
    float qdotk_max = -FLT_MAX;

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;
    for (int off = 0; off < rlen; off++) {

        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
@@ -158,31 +153,23 @@ __launch_bounds__(BDIM_X)
        alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        exp_save = expf(qdotk_max - qdotk_max_tmp);
        alpha_sum = alpha + alpha_sum * exp_save;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
        }
        qdotk_max = qdotk_max_tmp;
    }

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { y[batchId][chan][ho][wo] = shy[chan] / alpha_sum; }

    return;
}
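The kernel above keeps a running maximum qdotk_max and, whenever it grows, rescales the accumulated numerator and denominator by expf(qdotk_max - qdotk_max_tmp): the standard streaming (online) softmax. A scalar host-side sketch of the recurrence (illustrative only; the quadrature weight multiplying each term is omitted for clarity):

#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> scores = {2.0f, 5.0f, 3.0f, 4.0f}; // q.k per neighbor
    std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f}; // v per neighbor

    float m = -INFINITY, denom = 0.0f, num = 0.0f;
    for (size_t i = 0; i < scores.size(); i++) {
        const float m_new = std::fmax(m, scores[i]);
        const float alpha = std::exp(scores[i] - m_new);  // contribution of the new term
        const float rescale = std::exp(m - m_new);        // fixes up everything accumulated so far
        denom = alpha + denom * rescale;
        num = values[i] * alpha + num * rescale;
        m = m_new;
    }
    printf("attention output = %f\n", num / denom);       // softmax-weighted sum of values
    return 0;
}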
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
                                    at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out,
                                    int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
@@ -206,7 +193,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
// transpose inputs so that channels are in the last dimension, allowing for
// coalesced memory access
    nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");

    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
// printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
@@ -232,10 +219,10 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    nvtxRangePop();

    torch::Tensor y = torch::empty_like(qy);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * block.y;

    cudaEvent_t start, stop;
    float milliseconds = 0;
@@ -243,15 +230,14 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
@@ -267,4 +253,3 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
    return y;
}
torch_harmonics/csrc/attention/attention_interface.cu
@@ -31,8 +31,8 @@
#include "attention.cuh"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
    m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
torch_harmonics/csrc/disco/disco.h
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
torch_harmonics/csrc/disco/disco_cuda.cuh
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -36,32 +36,19 @@
#include <c10/cuda/CUDAStream.h>
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_INPUT_TENSOR(x) \
    CHECK_CUDA_TENSOR(x);          \
    CHECK_CONTIGUOUS_TENSOR(x)

#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))

#define MIN_THREADS (64)
#define ELXTH_MAX (32)
// forward kernel
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);
// backward kernel
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo);
torch_harmonics/csrc/disco/disco_cuda_bwd.cu
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -31,239 +31,175 @@
#include "disco.h"
#include "disco_cuda.cuh"
template<int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                            const int64_t* __restrict__ roff, const int64_t* __restrict__ kers,
                            const int64_t* __restrict__ rows, const int64_t* __restrict__ cols,
                            const REAL_T* __restrict__ vals, const REAL_T* __restrict__ inp, REAL_T* __restrict__ out)
{
    const int tid = threadIdx.x;

    const int64_t bidx = blockIdx.x; // global row
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
    out += bidy * Ho * Wo;

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[];

    // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
    REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);

    // copy current inp row in regs
    REAL_T __reg[ELXTH];
#pragma unroll
    for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }

    // reset shared row up to Wo+2, remaining
    // ppscale*(BDIM_X*ELXTH - Wo) locations
    // will be written to but never copied to
    // global mem
    for (int i = 0; i < pscale; i++) {
#pragma unroll
        for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
    }
    __syncthreads();

    int col_prev = cols[soff];
    int h_prev = col_prev / Wo;
    int w_prev = col_prev % Wo;

    // loops along the columns of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {

        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nz with a col value
        // leading to a new row of inp then copy it
        // to shmem;
        // we read a col that points to a new output
        // row if (col / Wo) > (col_prev / Wo)
        if (col >= col_prev - w_prev + Wo) {

            __syncthreads();
            for (int i = 0; i < pscale; i++) {
                for (int j = tid; j < Wi; j += BDIM_X) {
                    const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
                    atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
                    __sh[i][j] = 0;
                    __sh[i][Wi + j] = 0;
                }
            }
            __syncthreads();

            col_prev = col;
            h_prev = col / Wo;
            w_prev = col % Wo;
        }
        const int w = w_prev + (col - col_prev);
        const int w_mod_ps = w % pscale;
        const int w_div_ps = w / pscale;

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {
            const int pp = i * BDIM_X + tid;
            __sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
        }

        // to avoid race conditions on __sh[]
        // among consecutive iterations along nz
        __syncthreads();
    }
    __syncthreads();

    // write last row
    for (int i = 0; i < pscale; i++) {
        for (int j = tid; j < Wi; j += BDIM_X) {
            const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
            atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
        }
    }
    return;
}
template<int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__ __launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho,
                                                          const int Wo, const int pscale,
                                                          const int64_t* __restrict__ roff,
                                                          const int64_t* __restrict__ kers,
                                                          const int64_t* __restrict__ rows,
                                                          const int64_t* __restrict__ cols,
                                                          const REAL_T* __restrict__ vals,
                                                          const REAL_T* __restrict__ inp, REAL_T* __restrict__ out)
{
    if constexpr (PSCALE != 0) {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
    } else {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    }
    return;
}
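The PSCALE template parameter trades one extra kernel instantiation per common pscale value for a compile-time constant inside the hot loop, while PSCALE == 0 keeps a generic fallback that reads pscale from the runtime argument. A minimal sketch of the pattern (illustrative only, not part of this commit):

#include <cstdio>

// Compile-time multiplier when PSCALE != 0; runtime fallback when PSCALE == 0.
template<int PSCALE>
int scaled(int x, int pscale_runtime)
{
    if constexpr (PSCALE != 0) {
        return x * PSCALE;          // literal multiplier: the compiler can strength-reduce it
    } else {
        return x * pscale_runtime;  // generic path, value known only at runtime
    }
}

int main()
{
    printf("%d %d\n", scaled<2>(21, -1), scaled<0>(21, 2)); // both print 42
    return 0;
}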
template<int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t* roff_d,
                          int64_t* ker_d, int64_t* row_d, int64_t* col_d, REAL_T* val_d, REAL_T* inp_d, REAL_T* out_d,
                          cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wi) {
            dim3 grid(nrows, BC);

            const int pscale = Wo / Wi;
            size_t shmem = sizeof(*out_d) * (2 * (NTH * ELXTH) * pscale);

            switch (pscale) {
                case 1:
                    disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d,
                                                                                 ker_d, row_d, col_d, val_d, inp_d,
                                                                                 out_d);
                    break;
                case 2:
                    disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d,
                                                                                 ker_d, row_d, col_d, val_d, inp_d,
                                                                                 out_d);
                    break;
                case 3:
                    disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d,
                                                                                 ker_d, row_d, col_d, val_d, inp_d,
                                                                                 out_d);
                    break;
                default:
                    disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d,
                                                                                 ker_d, row_d, col_d, val_d, inp_d,
                                                                                 out_d);
            }
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
                                          out_d, stream);
        }
    }
    return;
}
torch::Tensor disco_cuda_bwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
    CHECK_CUDA_INPUT_TENSOR(roff_idx);
@@ -287,87 +223,54 @@ torch::Tensor disco_cuda_bwd(torch::Tensor inp,
    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<64, 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
// }
torch_harmonics/csrc/disco/disco_cuda_fwd.cu
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -31,101 +31,90 @@
#include "disco.h"
#include "disco_cuda.cuh"
template<int BDIM_X, int ELXTH, typename REAL_T>
__device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int Ho, const int Wo, const int pscale,
                            const int64_t* __restrict__ roff, const int64_t* __restrict__ kers,
                            const int64_t* __restrict__ rows, const int64_t* __restrict__ cols,
                            const REAL_T* __restrict__ vals, const REAL_T* __restrict__ inp, REAL_T* __restrict__ out)
{
    const int tid = threadIdx.x;

    const int64_t bidx = blockIdx.x; // global row
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * Hi * Wi;
    out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;

    REAL_T __reg[ELXTH] = {0};

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[];

    // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
    REAL_T* __sh = reinterpret_cast<REAL_T*>(__sh_ptr);

    int col_prev = cols[soff];
    int h_prev = col_prev / Wi;
    int w_prev = col_prev % Wi;

    // copy current inp row in shmem
    for (int i = tid; i < Wi; i += BDIM_X) {
        const REAL_T v = inp[h_prev * Wi + i];
        __sh[i] = v;
        __sh[Wi + i] = v;
    }
    // locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used
    __syncthreads();

    // loops along the columns of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {

        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nz with a col value
        // leading to a new row of inp then copy it
        // to shmem;
        // checks whether (h_prev < h) with:
        // (col >= col_prev - (col_prev % Wi) + Wi)
        if (col >= col_prev - w_prev + Wi) {

            col_prev = col;
            h_prev = col / Wi;
            w_prev = col % Wi;

            __syncthreads();
            for (int i = tid; i < Wi; i += BDIM_X) {
                const REAL_T v = inp[h_prev * Wi + i];
                __sh[i] = v;
                __sh[Wi + i] = v;
            }
            __syncthreads();
        }
        const int w = w_prev + (col - col_prev);

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {
            const int pp = i * BDIM_X + tid;

            // original lines:
            //
            // if (pp >= Wo) break;
            // const int wpp = (w + pscale*pp) % Wi;
            //
            // value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi
            // so we can allocate twice the amount of shmem,
            // replicate the current inp row and avoid the costly mod
            //
            // also, to avoid the conditional, sh can be extended to
            // cover the maximum location accessed during this loop
            //
            // REAL_T __sh[2*Wi + ppscale*NUM_REM]
@@ -135,113 +124,68 @@ __device__ void disco_fwd_d(const int Hi,
// = 2*Wi + ppscale*NUM_REM
//
// with NUM_REM = BDIM_X*ELXTH - Wo
            const int wpp = w + pscale * pp;
            __reg[i] += val * __sh[wpp];
        }
    }

#pragma unroll
    for (int i = 0; i < ELXTH; i++) {
        const int pp = i * BDIM_X + tid;
        if (pp >= Wo) break;
        out[pp] = __reg[i];
    }
    return;
}

template<int BDIM_X, int ELXTH, typename REAL_T>
__global__ __launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho,
                                                          const int Wo, const int pscale,
                                                          const int64_t* __restrict__ roff,
                                                          const int64_t* __restrict__ kers,
                                                          const int64_t* __restrict__ rows,
                                                          const int64_t* __restrict__ cols,
                                                          const REAL_T* __restrict__ vals,
                                                          const REAL_T* __restrict__ inp, REAL_T* __restrict__ out)
{
    disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    return;
}
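The comment block inside disco_fwd_d notes that (w + pscale*pp) stays below 2*Wi, so duplicating the cached input row lets the kernel replace the per-element modulo with a plain indexed load. A host-side sketch of that identity (illustrative only, not part of this commit):

#include <cassert>

int main()
{
    const int Wi = 8;
    float row[Wi], dup[2 * Wi];
    for (int i = 0; i < Wi; i++) { row[i] = float(i); dup[i] = dup[Wi + i] = float(i); }

    // any access (w + d) with w < Wi and d < Wi stays below 2*Wi,
    // so the duplicated row makes the wrap-around modulo unnecessary
    for (int w = 0; w < Wi; w++) {
        for (int d = 0; d < Wi; d++) { assert(dup[w + d] == row[(w + d) % Wi]); }
    }
    return 0;
}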
template<int NTH, int ELXTH, typename REAL_T>
static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows, int64_t* roff_d,
                          int64_t* ker_d, int64_t* row_d, int64_t* col_d, REAL_T* val_d, REAL_T* inp_d, REAL_T* out_d,
                          cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wo) {
            dim3 grid(nrows, BC);

            const int pscale = Wi / Wo;
            size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));

            disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d, row_d,
                                                                      col_d, val_d, inp_d, out_d);
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
                                          out_d, stream);
        }
    }
    return;
}
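launch_kernel searches for a per-thread element count at compile time: if NTH*ELXTH does not yet cover the output row, it recurses into the ELXTH + 1 instantiation, bounded by ELXTH_MAX. A minimal sketch of this dispatch (illustrative only, not part of this commit):

#include <cstdio>

template<int NTH, int ELXTH>
void pick(int width)
{
    if constexpr (ELXTH <= 32) {         // mirrors ELXTH_MAX
        if (NTH * ELXTH >= width) {
            printf("ELXTH = %d (%d threads x %d elements cover width %d)\n", ELXTH, NTH, ELXTH, width);
        } else {
            pick<NTH, ELXTH + 1>(width); // try the next instantiation
        }
    }
}

int main()
{
    pick<64, 1>(100); // prints ELXTH = 2, since 64*2 >= 100
    return 0;
}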
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
    CHECK_CUDA_INPUT_TENSOR(roff_idx);
@@ -265,83 +209,51 @@ torch::Tensor disco_cuda_fwd(torch::Tensor inp,
    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    // pick the correct launch config
    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<64, 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
                                       launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(
                                           BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
                                   }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
torch_harmonics/csrc/disco/disco_helpers.cpp
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -30,40 +30,30 @@
#include "disco.h"
template<typename REAL_T>
void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t* ker_h, int64_t* row_h, int64_t* col_h,
                           int64_t* roff_h, REAL_T* val_h, int64_t& nrows)
{
    int64_t* Koff = new int64_t[K];

    for (int i = 0; i < K; i++) { Koff[i] = 0; }
    for (int64_t i = 0; i < nnz; i++) { Koff[ker_h[i]]++; }

    int64_t prev = Koff[0];
    Koff[0] = 0;
    for (int i = 1; i < K; i++) {
        int64_t save = Koff[i];
        Koff[i] = prev + Koff[i - 1];
        prev = save;
    }

    int64_t* ker_sort = new int64_t[nnz];
    int64_t* row_sort = new int64_t[nnz];
    int64_t* col_sort = new int64_t[nnz];
    float* val_sort = new float[nnz];

    for (int64_t i = 0; i < nnz; i++) {

        const int64_t ker = ker_h[i];
        const int64_t off = Koff[ker]++;
@@ -73,31 +63,30 @@ void preprocess_psi_kernel(int64_t nnz,
        col_sort[off] = col_h[i];
        val_sort[off] = val_h[i];
    }

    for (int64_t i = 0; i < nnz; i++) {
        ker_h[i] = ker_sort[i];
        row_h[i] = row_sort[i];
        col_h[i] = col_sort[i];
        val_h[i] = val_sort[i];
    }

    delete[] Koff;
    delete[] ker_sort;
    delete[] row_sort;
    delete[] col_sort;
    delete[] val_sort;

    // compute rows offsets
    nrows = 1;
    roff_h[0] = 0;
    for (int64_t i = 1; i < nnz; i++) {
        if (row_h[i - 1] == row_h[i]) continue;

        roff_h[nrows++] = i;
        if (nrows > Ho * K) {
            fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%ld)\n", __FILE__, __LINE__,
                    int64_t(Ho) * K);
            exit(EXIT_FAILURE);
        }
    }
@@ -106,50 +95,40 @@ void preprocess_psi_kernel(int64_t nnz,
    return;
}
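preprocess_psi_kernel is a counting sort on the kernel index: histogram the nonzeros per kernel, exclusive-scan the counts into offsets, then scatter each COO entry to its slot so entries are grouped by kernel. A minimal sketch of that pass (illustrative only, not part of this commit):

#include <cstdio>
#include <vector>

int main()
{
    const int K = 3;
    std::vector<int> ker = {2, 0, 1, 0, 2, 1}; // kernel id per nonzero
    std::vector<int> count(K, 0), off(K, 0), sorted(ker.size());

    for (int k : ker) count[k]++;                                         // histogram
    for (int i = 1; i < K; i++) off[i] = off[i - 1] + count[i - 1];       // exclusive scan
    for (size_t i = 0; i < ker.size(); i++) sorted[off[ker[i]]++] = ker[i]; // scatter

    for (int k : sorted) printf("%d ", k); // prints: 0 0 1 1 2 2
    printf("\n");
    return 0;
}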
torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val)
{
    CHECK_INPUT_TENSOR(ker_idx);
    CHECK_INPUT_TENSOR(row_idx);
    CHECK_INPUT_TENSOR(col_idx);
    CHECK_INPUT_TENSOR(val);

    int64_t nnz = val.size(0);
    int64_t* ker_h = ker_idx.data_ptr<int64_t>();
    int64_t* row_h = row_idx.data_ptr<int64_t>();
    int64_t* col_h = col_idx.data_ptr<int64_t>();
    int64_t* roff_h = new int64_t[Ho * K + 1];
    int64_t nrows;

    // float *val_h = val.data_ptr<float>();
    AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
                                   preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h,
                                                                   val.data_ptr<scalar_t>(), nrows);
                               }));

    // create output tensor
    auto options = torch::TensorOptions().dtype(row_idx.dtype());
    auto roff_idx = torch::empty({nrows + 1}, options);
    int64_t* roff_out_h = roff_idx.data_ptr<int64_t>();
    for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }
    delete[] roff_h;

    return roff_idx;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
}
torch_harmonics/csrc/disco/disco_interface.cu
@@ -2,7 +2,7 @@
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
@@ -31,9 +31,8 @@
#include "disco.h"
#include "disco_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
    m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
}