OpenDAS / torch-harmonics · Commits

Commit c46b6925
Authored Jun 18, 2025 by Max Rietmann

    Applied new formatting

Parent: 1ea5c4ca

Showing 8 changed files with 771 additions and 759 deletions
.clang-format                                             +3    -3
torch_harmonics/csrc/attention/attention_bwd_cuda.cu      +246  -245
torch_harmonics/csrc/attention/attention_fwd_cuda.cu      +155  -155
torch_harmonics/csrc/attention/attention_interface.cu     +2    -2
torch_harmonics/csrc/disco/disco_cuda.cuh                 +3    -3
torch_harmonics/csrc/disco/disco_cuda_bwd.cu              +184  -179
torch_harmonics/csrc/disco/disco_cuda_fwd.cu              +176  -170
torch_harmonics/csrc/disco/disco_interface.cu             +2    -2
.clang-format  View file @ c46b6925

 ---
 BasedOnStyle: Webkit
-IndentWidth: 2
+IndentWidth: 4
 AccessModifierOffset: -2
 AlignAfterOpenBracket: Align
 AlignTrailingComments: true
...
@@ -13,8 +13,8 @@ BreakBeforeTernaryOperators: false
 BreakConstructorInitializers: AfterColon
 ColumnLimit: 120
 ConstructorInitializerAllOnOneLineOrOnePerLine: true
-ConstructorInitializerIndentWidth: 2
+ConstructorInitializerIndentWidth: 4
-ContinuationIndentWidth: 2
+ContinuationIndentWidth: 4
 Cpp11BracedListStyle: true
 FixNamespaceComments: true
 NamespaceIndentation: All
...
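The practical effect of raising IndentWidth (and the matching constructor-initializer and continuation widths) from 2 to 4 is shown below; this function is a made-up illustration, not code from the repository:

    #include <cstdio>

    // IndentWidth: 2 (before this commit)
    void log_if(bool cond, const char *msg) {
      if (cond) {
        std::puts(msg);
      }
    }

    // IndentWidth: 4 (after this commit)
    void log_if(bool cond, const char *msg) {
        if (cond) {
            std::puts(msg);
        }
    }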
torch_harmonics/csrc/attention/attention_bwd_cuda.cu  View file @ c46b6925

(Changes in this file: re-indentation and line re-wrapping only.)

...
@@ -51,17 +51,17 @@

#define THREADS (64)
#endif

#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call)                                                                                                \
    {                                                                                                                   \
        cudaError_t err = call;                                                                                         \
        if (cudaSuccess != err) {                                                                                       \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
            exit(EXIT_FAILURE);                                                                                         \
        }                                                                                                               \
    }
#endif

#include <iostream>
...
@@ -70,41 +70,42 @@

class ScopeTimer
{
  public:
    explicit ScopeTimer(const std::string &label = "") :
        label_(label), start_(std::chrono::high_resolution_clock::now())
    {
    }

    ~ScopeTimer()
    {
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
        std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
    }

  private:
    std::string label_;
    std::chrono::high_resolution_clock::time_point start_;
};

static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
    // 2. Broadcast sum to all threads
    sum = __shfl_sync(0xFFFFFFFF, sum, 0);
    return sum;
}
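A small usage sketch of the warp reductions above (the kernel name and output pointer are illustrative, not part of this file): each lane contributes its lane index, and after the reduction every lane holds the full warp sum, 0 + 1 + ... + 31 = 496.

    __global__ void warp_sum_demo(float *out)
    {
        const int lane = threadIdx.x % WARP_SIZE;            // assumes blockDim.x is a multiple of WARP_SIZE
        const float total = __warp_sum(float(lane));         // butterfly __shfl_xor_sync over i = 16, 8, 4, 2, 1
        if (lane == 0) out[threadIdx.x / WARP_SIZE] = total; // every lane holds 496.0f; write it once per warp
    }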
// This kernel computes the backward pass for the S2 attention mechanism, using
...
@@ -113,107 +114,107 @@ static __device__ float __warp_sum_cub(float val)
// memory access.
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
    float *sh_alpha_vw = sh_alpha_k + num_channels;
    float *sh_alpha_kvw = sh_alpha_vw + num_channels;
    float *sh_dy = sh_alpha_kvw + num_channels;
    float *sh_qy = sh_dy + num_channels;
    // (optionally, could use more shared memory for other intermediates)

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
    if (wid >= uint64_t(nlat_out) * nlon_in) return;

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);

    // Zero shared memory
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        sh_alpha_k[chan] = 0.0f;
        sh_alpha_vw[chan] = 0.0f;
        sh_alpha_kvw[chan] = 0.0f;
        sh_dy[chan] = dy[batchId][chan][ho][wo];
        sh_qy[chan] = qy[batchId][chan][ho][wo];
    }

    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;
    float integral = 0.0f;

    __syncthreads();

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f, gdotv = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        gdotv = __warp_sum_cub(gdotv);

        float qdotk_max_tmp = max(qdotk_max, qdotk);
        float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        float max_correction = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha_sum * max_correction + alpha_inz;
        integral = integral * max_correction + alpha_inz * gdotv;

        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            float kxval = kx[batchId][chan][hi][wip];
            sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
            sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv;
            sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
        }
        qdotk_max = qdotk_max_tmp;
    }

    integral /= alpha_sum;

    // Write dydq
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        dydq[batchId][chan][ho][wo] =
            (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
    }

    // Third pass: accumulate gradients for k and v
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f, gdotv = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        gdotv = __warp_sum_cub(gdotv);

        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            float qyval = qy[batchId][chan][ho][wo];
            float dyval = sh_dy[chan];
            atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
            atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
        }
    }
}
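The first pass above is an online (streaming) softmax: whenever a larger q·k score appears, the running accumulators are rescaled by expf(old_max - new_max), so expf never overflows. A scalar host-side sketch of that recurrence (the function name is illustrative; acc, m and corr play the roles of alpha_sum, qdotk_max and max_correction in the kernel):

    #include <cfloat>
    #include <cmath>

    // Returns sum_i expf(s[i] - max_j s[j]) * w[i], computed in a single pass over the scores.
    float streaming_weighted_expsum(const float *s, const float *w, int n)
    {
        float m = -FLT_MAX; // running maximum of the scores seen so far
        float acc = 0.0f;   // running sum, always expressed relative to the current maximum
        for (int i = 0; i < n; i++) {
            const float m_new = fmaxf(m, s[i]);
            const float corr = expf(m - m_new);            // rescale what was accumulated under the old maximum
            const float alpha = expf(s[i] - m_new) * w[i]; // contribution of the new score
            acc = acc * corr + alpha;
            m = m_new;
        }
        return acc;
    }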
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
...
@@ -222,122 +223,122 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
                                                                           int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);
    CHECK_CUDA_TENSOR(dy);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto k_channel_first = kx.strides()[1] == 1;
    auto v_channel_first = vx.strides()[1] == 1;
    auto q_channel_first = qy.strides()[1] == 1;
    auto dy_channel_first = dy.strides()[1] == 1;

    // Transpose to [batch, ho, wo, channel]
    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
    // auto* permute_timer = new ScopeTimer("permute inputs");
    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        kxP = kx;
    }
    auto vxP = at::Tensor();
    if (!v_channel_first) {
        // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        vxP = vx;
    }
    auto qyP = at::Tensor();
    if (!q_channel_first) {
        // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        qyP = qy;
    }
    auto dyP = at::Tensor();
    if (!dy_channel_first) {
        // printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        dyP = dy;
    }
    // cudaDeviceSynchronize();
    // delete permute_timer;
    nvtxRangePop();

    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
    auto dydk = torch::zeros_like(qyP);
    auto dydv = torch::zeros_like(qyP);
    auto dydq = torch::zeros_like(qyP);
    // print strdie of dydkP, dydvP, dydqP
    nvtxRangePop();

    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp

    cudaEvent_t start, stop;
    float milliseconds = 0;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));

    // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
    // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
    // s2_attention_bwd_kernel execution time: 51.231743 ms
    // s2_attention_bwd_kernel execution time: 52.971519 ms
    // s2_attention_bwd_kernel execution time: 50.724865 ms
    // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
    // s2_attention_bwd_kernel execution time: 11.679744 ms
    printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);

    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Permute outputs back to memory layout given by input. if input had channels
    // first, leave it in that layout, otherwise permute layout back to [batch,
    // channel, ho, wo]
    if (!k_channel_first) dydk = dydk.contiguous();
    if (!v_channel_first) dydv = dydv.contiguous();
    if (!q_channel_first) dydq = dydq.contiguous();

    // printf("dydk strides:[");
    // for(auto& stride : dydk.strides()) {
    //     printf("%ld,", stride);
    // }
    // printf("]\n");
    // cudaDeviceSynchronize();
    // delete permute_output_timer;
    // nvtxRangePop();

    return std::make_tuple(dydk, dydv, dydq);
}
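Both wrappers rely on the layout trick spelled out in the comments above: permute to channels-last, materialize that layout with contiguous(), then permute the view back so indexing stays [batch, channel, ho, wo] while the channel stride becomes 1. A standalone sketch using the libtorch C++ API (the helper name is illustrative, not part of the diff):

    #include <torch/torch.h>

    // x: [B, C, H, W]. The result indexes the same way, but channels are innermost in memory.
    torch::Tensor to_channels_last_view(const torch::Tensor &x)
    {
        auto xP = x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
        TORCH_CHECK(xP.sizes().equals(x.sizes()));
        TORCH_CHECK(xP.strides()[1] == 1); // the same "channel_first" test the wrappers use on their inputs
        return xP;
    }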
torch_harmonics/csrc/attention/attention_fwd_cuda.cu  View file @ c46b6925

(Changes in this file: re-indentation and line re-wrapping only.)

...
@@ -45,125 +45,125 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;

#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define NNZ_TRESH (32)

#define CHECK_CUDA(call)                                                                                                \
    {                                                                                                                   \
        cudaError_t err = call;                                                                                         \
        if (cudaSuccess != err) {                                                                                       \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
            exit(EXIT_FAILURE);                                                                                         \
        }                                                                                                               \
    }

#define CHECK_ERROR(errorMessage)                                                                            \
    {                                                                                                        \
        cudaError_t err = cudaGetLastError();                                                                \
        if (cudaSuccess != err) {                                                                            \
            fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", errorMessage, __FILE__, __LINE__, \
                    cudaGetErrorString(err));                                                                \
            exit(EXIT_FAILURE);                                                                              \
        }                                                                                                    \
    }

static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
    // 2. Broadcast sum to all threads
    sum = __shfl_sync(0xFFFFFFFF, sum, 0);
    return sum;
}

// one warp per (ho,wo)
template <int BDIM_X>
__global__ __launch_bounds__(BDIM_X) void s2_attention_kernel(
    int num_channels, int nlon_in, int nlat_out, int nlon_out,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
    const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
    torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> y,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
    const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *shy = sh + threadIdx.y * num_channels;

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;
    if (wid >= uint64_t(nlat_out) * nlon_in) { return; }

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
#if 0
        // useless read, y is always zeroed before kernel is called
        shy[chan] = y[batchId][chan][ho][wo];
#else
        shy[chan] = 0;
#endif
    }

    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);

        float qdotk_max_tmp;
        float alpha;
        float exp_save;
        qdotk_max_tmp = max(qdotk_max, qdotk);
        alpha = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        exp_save = expf(qdotk_max - qdotk_max_tmp);
        alpha_sum = alpha + alpha_sum * exp_save;

        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            shy[chan] = shy[chan] * exp_save + vx[batchId][chan][hi][wip] * alpha;
        }
        qdotk_max = qdotk_max_tmp;
    }

    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        y[batchId][chan][ho][wo] = shy[chan] / alpha_sum;
    }

    return;
}
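The psi_row_offset / psi_col_idx pair is read here exactly like a CSR sparse matrix: output row ho owns the slice [rbeg, rend) of the column-index array, and each column index encodes an input point as hi * nlon_in + wi. A minimal host-side sketch of that traversal (the function name is illustrative):

    #include <cstdint>

    void visit_row_neighbourhood(const int64_t *psi_row_offset, const int64_t *psi_col_idx, int ho, int nlon_in)
    {
        const int64_t rbeg = psi_row_offset[ho];
        const int64_t rend = psi_row_offset[ho + 1];
        for (int64_t off = rbeg; off < rend; off++) {
            const int64_t col = psi_col_idx[off];
            const int hi = static_cast<int>(col / nlon_in);      // input latitude index
            const int wi = static_cast<int>(col - hi * nlon_in); // input longitude index
            (void)hi;
            (void)wi; // the kernels additionally shift wi by the output longitude wo, modulo nlon_in
        }
    }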
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights,
...
@@ -171,85 +171,85 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                    int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);

    // TODO: check sizes

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    auto k_channel_first = kx.strides()[1] == 1;
    auto v_channel_first = vx.strides()[1] == 1;
    auto q_channel_first = qy.strides()[1] == 1;

    // transpose inputs so that channels are in the last dimension, allowing for
    // coalesced memory access
    nvtxRangePush("s2_attention_fwd_kernel_mbT permute inputs");
    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
    auto kxP = at::Tensor();
    if (!k_channel_first) {
        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        kxP = kx;
    }
    auto vxP = at::Tensor();
    if (!v_channel_first) {
        // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        vxP = vx;
    }
    auto qyP = at::Tensor();
    if (!q_channel_first) {
        // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
        qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
    } else {
        qyP = qy;
    }
    cudaDeviceSynchronize();
    nvtxRangePop();

    torch::Tensor y = torch::empty_like(qy);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * block.y;

    cudaEvent_t start, stop;
    float milliseconds = 0;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));

    s2_attention_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        y.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    // printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    // match output layout to input layout
    if (!q_channel_first) y = y.contiguous();

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return y;
}
torch_harmonics/csrc/attention/attention_interface.cu  View file @ c46b6925

(Changes in this file: re-indentation only.)

...
@@ -33,6 +33,6 @@

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &s2_attention_fwd_cuda, "(Local) Attention on S2");
    m.def("backward_dkvq", &s2_attention_bwd_dkvq_cuda, "(Local) Attention gradient on S2 (gradient for k,v,&q)");
}
torch_harmonics/csrc/disco/disco_cuda.cuh  View file @ c46b6925

(Changes in this file: whitespace only.)

...
@@ -37,10 +37,10 @@

#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")

#define CHECK_CUDA_INPUT_TENSOR(x) \
    CHECK_CUDA_TENSOR(x);          \
    CHECK_CONTIGUOUS_TENSOR(x)

#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))

#define MIN_THREADS (64)
#define ELXTH_MAX (32)
...
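DIV_UP is the usual round-up integer division used to size grids so that every element is covered. A few compile-time checks (illustrative, using the same definition) make the behaviour concrete:

    #define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
    static_assert(DIV_UP(7, 4) == 2, "7 items in blocks of 4 -> 2 blocks");
    static_assert(DIV_UP(8, 4) == 2, "exact multiples are not over-allocated");
    static_assert(DIV_UP(9, 4) == 3, "one extra item adds one block");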
torch_harmonics/csrc/disco/disco_cuda_bwd.cu  View file @ c46b6925

(Changes in this file: re-indentation and line re-wrapping only.)

...
@@ -38,122 +38,122 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
                            const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    const int tid = threadIdx.x;
    const int64_t bidx = blockIdx.x; // gloabl row
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * K * Hi * Wi + ker * Hi * Wi + row * Wi;
    out += bidy * Ho * Wo;

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
    REAL_T(*__sh)[BDIM_X * ELXTH * 2] = reinterpret_cast<REAL_T(*)[BDIM_X * ELXTH * 2]>(__sh_ptr);

    // copy current inp row in regs
    REAL_T __reg[ELXTH];
#pragma unroll
    for (int i = 0; i < ELXTH; i++) { __reg[i] = (i * BDIM_X + tid < Wi) ? inp[i * BDIM_X + tid] : REAL_T(0); }

    // reset shared row up to Wo+2, remaining
    // ppscale*(BDIM_X*ELXTH - Wo) locations
    // will be written to but never copied to
    // global mem
    for (int i = 0; i < pscale; i++) {
#pragma unroll
        for (int j = 0; j < 2 * BDIM_X * ELXTH; j += BDIM_X) { __sh[i][j + tid] = 0; }
    }
    __syncthreads();

    int col_prev = cols[soff];
    int h_prev = col_prev / Wo;
    int w_prev = col_prev % Wo;

    // loops along the colums of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {
        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nz with a col value
        // leading to a new row of inp then copy it
        // to shmem;
        // we read a col that points to a new output
        // row if (col / Wo) > (col_prev / Wo)
        if (col >= col_prev - w_prev + Wo) {
            __syncthreads();
            for (int i = 0; i < pscale; i++) {
                for (int j = tid; j < Wi; j += BDIM_X) {
                    const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
                    atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
                    __sh[i][j] = 0;
                    __sh[i][Wi + j] = 0;
                }
            }
            __syncthreads();

            col_prev = col;
            h_prev = col / Wo;
            w_prev = col % Wo;
        }

        const int w = w_prev + (col - col_prev);
        const int w_mod_ps = w % pscale;
        const int w_div_ps = w / pscale;
#pragma unroll
        for (int i = 0; i < ELXTH; i++) {
            const int pp = i * BDIM_X + tid;
            __sh[w_mod_ps][w_div_ps + pp] += val * __reg[i];
        }
        // to avoid race conditions on __sh[]
        // among consecutive iterations along nz
        __syncthreads();
    }
    __syncthreads();

    // write last row
    for (int i = 0; i < pscale; i++) {
        for (int j = tid; j < Wi; j += BDIM_X) {
            const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
            atomicAdd(&out[h_prev * Wo + j * pscale + i], v);
        }
    }
    return;
}

template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__ __launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                                                           const int pscale, const int64_t *__restrict__ roff,
                                                           const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
                                                           const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
                                                           const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    if constexpr (PSCALE != 0) {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out);
    } else {
        disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    }
    return;
}

template <int NTH, int ELXTH, typename REAL_T>
...
@@ -162,113 +162,118 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
...
@@ -162,113 +162,118 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
cudaStream_t
stream
)
cudaStream_t
stream
)
{
{
static_assert
(
sizeof
(
REAL_T
)
==
2
||
sizeof
(
REAL_T
)
==
4
||
sizeof
(
REAL_T
)
==
8
);
static_assert
(
sizeof
(
REAL_T
)
==
2
||
sizeof
(
REAL_T
)
==
4
||
sizeof
(
REAL_T
)
==
8
);
if
constexpr
(
ELXTH
<=
ELXTH_MAX
)
{
if
constexpr
(
ELXTH
<=
ELXTH_MAX
)
{
if
(
NTH
*
ELXTH
>=
Wi
)
{
if
(
NTH
*
ELXTH
>=
Wi
)
{
dim3
grid
(
nrows
,
BC
);
dim3
grid
(
nrows
,
BC
);
const
int
pscale
=
Wo
/
Wi
;
const
int
pscale
=
Wo
/
Wi
;
size_t
shmem
=
sizeof
(
*
out_d
)
*
(
2
*
(
NTH
*
ELXTH
)
*
pscale
);
size_t
shmem
=
sizeof
(
*
out_d
)
*
(
2
*
(
NTH
*
ELXTH
)
*
pscale
);
switch
(
pscale
)
{
switch
(
pscale
)
{
case
1
:
case
1
:
disco_bwd_blk_k
<
NTH
,
ELXTH
,
1
>
disco_bwd_blk_k
<
NTH
,
ELXTH
,
1
><<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
break
;
break
;
case
2
:
case
2
:
disco_bwd_blk_k
<
NTH
,
ELXTH
,
2
>
disco_bwd_blk_k
<
NTH
,
ELXTH
,
2
><<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
break
;
break
;
case
3
:
case
3
:
disco_bwd_blk_k
<
NTH
,
ELXTH
,
3
>
disco_bwd_blk_k
<
NTH
,
ELXTH
,
3
><<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
break
;
break
;
default:
default:
disco_bwd_blk_k
<
NTH
,
ELXTH
,
0
>
disco_bwd_blk_k
<
NTH
,
ELXTH
,
0
><<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
<<<
grid
,
NTH
,
shmem
,
stream
>>>
(
Hi
,
Wi
,
K
,
Ho
,
Wo
,
pscale
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
);
}
}
}
else
{
}
else
{
launch_kernel
<
NTH
,
ELXTH
+
1
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
out_d
,
launch_kernel
<
NTH
,
ELXTH
+
1
>
(
BC
,
Hi
,
Wi
,
K
,
Ho
,
Wo
,
nrows
,
roff_d
,
ker_d
,
row_d
,
col_d
,
val_d
,
inp_d
,
stream
);
out_d
,
stream
);
}
}
}
}
return
;
return
;
}
}
torch
::
Tensor
disco_cuda_bwd
(
torch
::
Tensor
inp
,
torch
::
Tensor
roff_idx
,
torch
::
Tensor
ker_idx
,
torch
::
Tensor
row_idx
,
torch
::
Tensor
disco_cuda_bwd
(
torch
::
Tensor
inp
,
torch
::
Tensor
roff_idx
,
torch
::
Tensor
ker_idx
,
torch
::
Tensor
row_idx
,
torch
::
Tensor
col_idx
,
torch
::
Tensor
val
,
int64_t
K
,
int64_t
Ho
,
int64_t
Wo
)
torch
::
Tensor
col_idx
,
torch
::
Tensor
val
,
int64_t
K
,
int64_t
Ho
,
int64_t
Wo
)
{
{
// some sanity checks
// some sanity checks
CHECK_CUDA_INPUT_TENSOR
(
inp
);
CHECK_CUDA_INPUT_TENSOR
(
inp
);
CHECK_CUDA_INPUT_TENSOR
(
roff_idx
);
CHECK_CUDA_INPUT_TENSOR
(
roff_idx
);
CHECK_CUDA_INPUT_TENSOR
(
ker_idx
);
CHECK_CUDA_INPUT_TENSOR
(
ker_idx
);
CHECK_CUDA_INPUT_TENSOR
(
row_idx
);
CHECK_CUDA_INPUT_TENSOR
(
row_idx
);
CHECK_CUDA_INPUT_TENSOR
(
col_idx
);
CHECK_CUDA_INPUT_TENSOR
(
col_idx
);
CHECK_CUDA_INPUT_TENSOR
(
val
);
CHECK_CUDA_INPUT_TENSOR
(
val
);
// extract some shapes
// extract some shapes
int64_t
B
=
inp
.
size
(
0
);
int64_t
B
=
inp
.
size
(
0
);
int64_t
C
=
inp
.
size
(
1
);
int64_t
C
=
inp
.
size
(
1
);
int64_t
BC
=
B
*
C
;
int64_t
BC
=
B
*
C
;
int64_t
Hi
=
inp
.
size
(
3
);
int64_t
Hi
=
inp
.
size
(
3
);
int64_t
Wi
=
inp
.
size
(
4
);
int64_t
Wi
=
inp
.
size
(
4
);
int64_t
nrows
=
roff_idx
.
size
(
0
)
-
1
;
int64_t
nrows
=
roff_idx
.
size
(
0
)
-
1
;
// allocate output
// allocate output
int64_t
out_dims
[]
=
{
B
,
C
,
Ho
,
Wo
};
int64_t
out_dims
[]
=
{
B
,
C
,
Ho
,
Wo
};
auto
options
=
torch
::
TensorOptions
().
device
(
inp
.
device
()).
dtype
(
inp
.
dtype
());
auto
options
=
torch
::
TensorOptions
().
device
(
inp
.
device
()).
dtype
(
inp
.
dtype
());
torch
::
Tensor
out
=
torch
::
zeros
(
out_dims
,
options
);
torch
::
Tensor
out
=
torch
::
zeros
(
out_dims
,
options
);
// get stream
// get stream
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
// assert
// assert
    static_assert(0 == (ELXTH_MAX % 2));

    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
            launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                               roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                               row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                               val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                               out.data_ptr<scalar_t>(), stream);
        }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
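The Wo-threshold chain above, together with the recursive launch_kernel template, effectively picks the smallest block width BDIM_X in {64, 128, 256, 512, 1024} and the smallest per-thread element count ELXTH such that BDIM_X * ELXTH >= Wo. A minimal standalone sketch of that selection logic follows; it is illustrative only, not repository code, and ELXTH_MAX = 32 as well as the function names are assumptions.

// dispatch_sketch.cpp -- illustrative only; mirrors the selection pattern, not the real kernels.
#include <cstdio>

constexpr int ELXTH_MAX = 32; // assumed value of the ELXTH_MAX macro

template <int BDIM_X, int ELXTH>
void launch_kernel_sketch(int Wo)
{
    if constexpr (ELXTH <= ELXTH_MAX) {
        if (BDIM_X * ELXTH >= Wo) {
            // here the real code would launch the templated CUDA kernel with these parameters
            std::printf("Wo=%d -> BDIM_X=%d, ELXTH=%d\n", Wo, BDIM_X, ELXTH);
        } else {
            launch_kernel_sketch<BDIM_X, ELXTH + 1>(Wo); // grow per-thread work until the row fits
        }
    }
}

void dispatch_sketch(int Wo)
{
    if (Wo <= 64 * ELXTH_MAX)        launch_kernel_sketch<64, 1>(Wo);
    else if (Wo <= 128 * ELXTH_MAX)  launch_kernel_sketch<128, (ELXTH_MAX / 2) + 1>(Wo);
    else if (Wo <= 256 * ELXTH_MAX)  launch_kernel_sketch<256, (ELXTH_MAX / 2) + 1>(Wo);
    else if (Wo <= 512 * ELXTH_MAX)  launch_kernel_sketch<512, (ELXTH_MAX / 2) + 1>(Wo);
    else if (Wo <= 1024 * ELXTH_MAX) launch_kernel_sketch<1024, (ELXTH_MAX / 2) + 1>(Wo);
    // larger Wo values are rejected with an error in the real code
}

int main()
{
    dispatch_sketch(1440); // e.g. a 1440-point longitude grid
    return 0;
}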
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
torch_harmonics/csrc/disco/disco_cuda_fwd.cu
View file @ c46b6925
...
@@ -38,123 +38,124 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
                                  const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp,
                                  REAL_T *__restrict__ out)
{
    const int tid = threadIdx.x;

    const int64_t bidx = blockIdx.x; // global row
    const int64_t bidy = blockIdx.y; // bc

    int64_t soff = roff[bidx];
    int64_t eoff = roff[bidx + 1];

    const int64_t ker = kers[soff];
    const int64_t row = rows[soff];

    inp += bidy * Hi * Wi;
    out += bidy * K * Ho * Wo + ker * Ho * Wo + row * Wo;

    REAL_T __reg[ELXTH] = {0};

    // align to larger supported fp type
    extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + ppscale*(BDIM_X*ELXTH - Wo)]
    REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);

    int col_prev = cols[soff];
    int h_prev = col_prev / Wi;
    int w_prev = col_prev % Wi;

    // copy current inp row in shmem
    for (int i = tid; i < Wi; i += BDIM_X) {
        const REAL_T v = inp[h_prev * Wi + i];
        __sh[i] = v;
        __sh[Wi + i] = v;
    }
    // locations __sh[2*Wi : ppscale*(BDIM_X*ELXTH-Wo)] are not used

    __syncthreads();

    // loops along the columns of CTA's row
    for (int64_t nz = soff; nz < eoff; nz++) {

        const int col = cols[nz];
        const REAL_T val = vals[nz];

        // if we are processing a nz with a col value
        // leading to a new row of inp then copy it
        // to shmem;
        // checks whether (h_prev < h) with:
        // (col >= col_prev - (col_prev % Wi) + Wi)
        if (col >= col_prev - w_prev + Wi) {

            col_prev = col;
            h_prev = col / Wi;
            w_prev = col % Wi;

            __syncthreads();
            for (int i = tid; i < Wi; i += BDIM_X) {
                const REAL_T v = inp[h_prev * Wi + i];
                __sh[i] = v;
                __sh[Wi + i] = v;
            }
            __syncthreads();
        }
        const int w = w_prev + (col - col_prev);

#pragma unroll
        for (int i = 0; i < ELXTH; i++) {

            const int pp = i * BDIM_X + tid;

            // original lines:
            //
            // if (pp >= Wo) break;
            // const int wpp = (w + pscale*pp) % Wi;
            //
            // value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi
            // so we can allocate twice the amount of shmem,
            // replicate the current inp row and avoid the costly mod
            //
            // also, to avoid the conditional, sh can be extended to
            // cover the maximum location accessed during this loop
            //
            // REAL_T __sh[2*Wi + ppscale*NUM_REM]
            //
            // Wi + (Wi/Wo)*BDIM_X*ELXTH = (since BDIM_X*ELXTH >= Wo) =
            // = Wi + (Wi/Wo)*(Wo + (BDIM_X*ELXTH - Wo)) =
            // = 2*Wi + ppscale*NUM_REM
            //
            // with NUM_REM = BDIM_X*ELXTH - Wo
            const int wpp = w + pscale * pp;

            __reg[i] += val * __sh[wpp];
        }
    }

#pragma unroll
    for (int i = 0; i < ELXTH; i++) {

        const int pp = i * BDIM_X + tid;
        if (pp >= Wo) break;

        out[pp] = __reg[i];
    }
    return;
}
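For reference, the dynamic shared-memory footprint implied by the comment above is sizeof(REAL_T) * (2*Wi + pscale*(BDIM_X*ELXTH - Wo)) with pscale = Wi/Wo. A small host-side sketch with assumed example values (illustrative only, not repository code):

// shmem_sketch.cpp -- illustrative only.
#include <cstdio>

static size_t disco_fwd_shmem_bytes(int Wi, int Wo, int bdim_x, int elxth, size_t real_size)
{
    const int pscale = Wi / Wo; // Wi is assumed to be an integer multiple of Wo
    return real_size * (2 * Wi + pscale * (bdim_x * elxth - Wo));
}

int main()
{
    // Assumed example: Wi = 256, Wo = 128, BDIM_X = 64, ELXTH = 2, float data.
    // pscale = 2, replicated row = 2*256 = 512 elements, padding = 2*(64*2 - 128) = 0,
    // so the kernel would request 512 * 4 = 2048 bytes of dynamic shared memory.
    std::printf("%zu bytes\n", disco_fwd_shmem_bytes(256, 128, 64, 2, sizeof(float)));
    return 0;
}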
template <int BDIM_X, int ELXTH, typename REAL_T>
__global__ __launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho,
                                                           const int Wo, const int pscale,
                                                           const int64_t *__restrict__ roff,
                                                           const int64_t *__restrict__ kers,
                                                           const int64_t *__restrict__ rows,
                                                           const int64_t *__restrict__ cols,
                                                           const REAL_T *__restrict__ vals,
                                                           const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
    disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
    return;
}

template <int NTH, int ELXTH, typename REAL_T>
...
@@ -163,97 +164,102 @@ static void launch_kernel(int BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t
                          cudaStream_t stream)
{
    static_assert(sizeof(REAL_T) == 2 || sizeof(REAL_T) == 4 || sizeof(REAL_T) == 8);

    if constexpr (ELXTH <= ELXTH_MAX) {
        if (NTH * ELXTH >= Wo) {
            dim3 grid(nrows, BC);

            const int pscale = Wi / Wo;
            size_t shmem = sizeof(*out_d) * (Wi * 2 + pscale * (NTH * ELXTH - Wo));

            disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi, K, Ho, Wo, pscale, roff_d, ker_d,
                                                                      row_d, col_d, val_d, inp_d, out_d);
        } else {
            launch_kernel<NTH, ELXTH + 1>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_d, ker_d, row_d, col_d, val_d, inp_d,
                                          out_d, stream);
        }
    }
    return;
}
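Functionally, each CTA launched here handles one CSR row for one batch*channel slice (via grid(nrows, BC)) and produces one (ker, row) output row. A CPU reference sketch of the same contraction, under the assumption that the CSR arrays are laid out exactly as consumed by disco_fwd_d above (illustrative only, not repository code):

// disco_fwd_reference.cpp -- illustrative only.
#include <cstdint>
#include <vector>

template <typename REAL_T>
void disco_fwd_reference(int64_t BC, int Hi, int Wi, int K, int Ho, int Wo, int64_t nrows,
                         const std::vector<int64_t> &roff, const std::vector<int64_t> &kers,
                         const std::vector<int64_t> &rows, const std::vector<int64_t> &cols,
                         const std::vector<REAL_T> &vals, const std::vector<REAL_T> &inp,
                         std::vector<REAL_T> &out) // out has size BC*K*Ho*Wo
{
    const int pscale = Wi / Wo;
    for (int64_t bc = 0; bc < BC; bc++) {
        for (int64_t r = 0; r < nrows; r++) {
            const int64_t ker = kers[roff[r]]; // kernel index and output row are constant within a CSR row
            const int64_t row = rows[roff[r]];
            std::vector<REAL_T> acc(Wo, REAL_T(0));
            for (int64_t nz = roff[r]; nz < roff[r + 1]; nz++) {
                const int h = static_cast<int>(cols[nz] / Wi);
                const int w = static_cast<int>(cols[nz] % Wi);
                for (int pp = 0; pp < Wo; pp++) {
                    acc[pp] += vals[nz] * inp[bc * Hi * Wi + h * Wi + (w + pscale * pp) % Wi];
                }
            }
            for (int pp = 0; pp < Wo; pp++) { out[((bc * K + ker) * Ho + row) * Wo + pp] = acc[pp]; }
        }
    }
}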
torch::Tensor disco_cuda_fwd(torch::Tensor inp, torch::Tensor roff_idx, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val, int64_t K, int64_t Ho, int64_t Wo)
{
    // some sanity checks
    CHECK_CUDA_INPUT_TENSOR(inp);
    CHECK_CUDA_INPUT_TENSOR(roff_idx);
    CHECK_CUDA_INPUT_TENSOR(ker_idx);
    CHECK_CUDA_INPUT_TENSOR(row_idx);
    CHECK_CUDA_INPUT_TENSOR(col_idx);
    CHECK_CUDA_INPUT_TENSOR(val);

    // extract some shapes
    int64_t B = inp.size(0);
    int64_t C = inp.size(1);
    int64_t BC = B * C;
    int64_t Hi = inp.size(2);
    int64_t Wi = inp.size(3);
    int64_t nrows = roff_idx.size(0) - 1;

    // allocate output
    int64_t out_dims[] = {B, C, K, Ho, Wo};
    auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
    torch::Tensor out = torch::zeros(out_dims, options);

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // assert
    static_assert(0 == (ELXTH_MAX % 2));

    // pick the correct launch config
    if (Wo <= 64 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows, roff_idx.data_ptr<int64_t>(),
                                           ker_idx.data_ptr<int64_t>(), row_idx.data_ptr<int64_t>(),
                                           col_idx.data_ptr<int64_t>(), val.data_ptr<scalar_t>(),
                                           inp.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 128 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<128, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 256 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<256, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 512 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<512, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                              roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                              row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                              val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                              out.data_ptr<scalar_t>(), stream);
        }));
    } else if (Wo <= 1024 * ELXTH_MAX) {
        AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
            launch_kernel<1024, (ELXTH_MAX / 2) + 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
                                                               roff_idx.data_ptr<int64_t>(), ker_idx.data_ptr<int64_t>(),
                                                               row_idx.data_ptr<int64_t>(), col_idx.data_ptr<int64_t>(),
                                                               val.data_ptr<scalar_t>(), inp.data_ptr<scalar_t>(),
                                                               out.data_ptr<scalar_t>(), stream);
        }));
    } else {
        fprintf(stderr, "%s:%d: error, unsupported Wo value (%ld), max supported is %d\n", __FILE__, __LINE__, Wo,
                1024 * ELXTH_MAX);
        exit(EXIT_FAILURE);
    }

    return out;
}
torch_harmonics/csrc/disco/disco_interface.cu
View file @ c46b6925
...
@@ -33,6 +33,6 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
    m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
}