OpenDAS / torch-harmonics · Commits

Commit e1338191
Authored Jul 02, 2025 by Thorsten Kurth

    using torch tools to change layout in bd pass

Parent: 49a61eee
Changes: 1 changed file, with 277 additions and 257 deletions

torch_harmonics/csrc/attention/attention_bwd_cuda.cu (+277, -257)
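The message "using torch tools to change layout in bd pass" refers to swapping a manual permute round-trip for PyTorch's built-in memory-format conversion (see the later hunks). Below is a minimal standalone sketch of why the two routes are equivalent; the shape and variable names are illustrative, not from the commit, and `contiguous(at::MemoryFormat::ChannelsLast)` stands in for the diff's `.to(torch::kFloat32, at::MemoryFormat::ChannelsLast)` since only the layout change matters here:

    // Sketch: two ways to get a channels-last layout while keeping the
    // logical shape [batch, channel, h, w]. Compile against libtorch.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
        auto x = torch::randn({2, 8, 4, 6}); // stand-in for kx/vx/qy/dy

        // Manual route (the old code): permute to NHWC, materialize,
        // then permute the view back so the logical shape is unchanged.
        auto manual = x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});

        // Torch-native route (the new code): request channels-last directly.
        auto native = x.contiguous(at::MemoryFormat::ChannelsLast);

        // Both print the same strides: [192, 1, 48, 8] (channel stride 1).
        std::cout << manual.strides() << "\n" << native.strides() << std::endl;
        std::cout << native.is_contiguous(at::MemoryFormat::ChannelsLast) << std::endl; // 1
        return 0;
    }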
@@ -51,53 +51,49 @@
 #define THREADS (64)
 #endif
 
 #ifndef DIV_UP
-#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
+#define DIV_UP(a,b) (((a) + ((b)-1)) / (b))
 #endif
 
 #ifndef CHECK_CUDA
-#define CHECK_CUDA(call) \
-{ \
+#define CHECK_CUDA(call) { \
     cudaError_t err = call; \
-    if (cudaSuccess != err) { \
-        fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
+    if( cudaSuccess != err) { \
+        fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \
+                __FILE__, __LINE__, cudaGetErrorString( err) ); \
         exit(EXIT_FAILURE); \
-    } \
-}
+    }}
 #endif
 
 #include <iostream>
 #include <chrono>
 #include <string>
 
 class ScopeTimer
 {
   public:
-    explicit ScopeTimer(const std::string &label = "") : label_(label), start_(std::chrono::high_resolution_clock::now()) { }
+    explicit ScopeTimer(const std::string &label = "") : label_(label), start_(std::chrono::high_resolution_clock::now()) {}
 
     ~ScopeTimer()
     {
         auto end = std::chrono::high_resolution_clock::now();
         auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
         std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
     }
 
   private:
     std::string label_;
     std::chrono::high_resolution_clock::time_point start_;
 };
 
 static __device__ float __warp_sum(float val)
 {
 #pragma unroll
     for (int i = WARP_SIZE / 2; i; i /= 2) {
         val += __shfl_xor_sync(FULL_MASK, val, i);
     }
     return val;
 }
 
 // easier to understand version of manual shfl_xor_sync, performance appears similar
 static __device__ float __warp_sum_cub(float val)
 {
     // use cub to reduce within a warp
     __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
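For context on the hunk above: `__warp_sum` is the standard XOR-shuffle butterfly reduction, which leaves every lane of the warp holding the sum of all 32 lane values. A self-contained toy program (not from the repo) demonstrating the same pattern:

    // Toy demo of the butterfly warp reduction used by __warp_sum.
    #include <cstdio>
    #include <cuda_runtime.h>

    #define FULL_MASK 0xffffffffu
    #define WARP_SIZE 32

    __device__ float warp_sum(float val)
    {
        // strides 16, 8, 4, 2, 1: after the loop, all lanes hold the total
        for (int i = WARP_SIZE / 2; i; i /= 2) { val += __shfl_xor_sync(FULL_MASK, val, i); }
        return val;
    }

    __global__ void demo(float *out)
    {
        // each lane contributes its lane id: 0 + 1 + ... + 31 = 496
        float v = warp_sum((float)threadIdx.x);
        if (threadIdx.x == 0) *out = v;
    }

    int main()
    {
        float *d_out, h_out;
        cudaMalloc(&d_out, sizeof(float));
        demo<<<1, WARP_SIZE>>>(d_out);
        cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
        printf("warp sum = %f (expect 496)\n", h_out);
        cudaFree(d_out);
        return 0;
    }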
@@ -112,9 +108,14 @@ static __device__ float __warp_sum_cub(float val)
 // shared memory as a cache and one warp per output point, warp-parallel over
 // channels, which should be layed out in the fastest dimension for coalesced
 // memory access.
 template<int BDIM_X>
 __global__ __launch_bounds__(BDIM_X)
 void s2_attention_bwd_dkvq_kernel(int num_channels, int nlon_in, int nlat_out, int nlon_out,
                                   const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
                                   const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
                                   const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
@@ -124,15 +125,14 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
     torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
     const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
     const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
     const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
 {
     extern __shared__ float sh[];
     float *sh_alpha_k = sh + threadIdx.y * num_channels * 5;
     float *sh_alpha_vw = sh_alpha_k + num_channels;
     float *sh_alpha_kvw = sh_alpha_vw + num_channels;
     float *sh_dy = sh_alpha_kvw + num_channels;
     float *sh_qy = sh_dy + num_channels;
     // (optionally, could use more shared memory for other intermediates)
 
     const uint64_t batchId = blockIdx.y;
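The pointer arithmetic above carves one dynamic shared-memory allocation into five per-warp arrays of `num_channels` floats each, indexed by the warp's `threadIdx.y`. A toy sketch of the same carving, with illustrative sizes (names and numbers are not from the commit):

    // Demo of per-warp partitioning of a single extern __shared__ buffer.
    #include <cstdio>
    #include <cuda_runtime.h>

    #define NUM_ARRAYS 5

    __global__ void carve(int num_channels)
    {
        extern __shared__ float sh[];
        float *sh_a = sh + threadIdx.y * num_channels * NUM_ARRAYS; // this warp's base
        float *sh_b = sh_a + num_channels;                          // next slice, etc.
        if (threadIdx.x == 0) {
            sh_a[0] = 1.0f;
            sh_b[0] = 2.0f;
            printf("warp %d: base offset %ld floats\n", threadIdx.y, (long)(sh_a - sh));
        }
    }

    int main()
    {
        const int num_channels = 16, warps = 2;
        dim3 block(32, warps);
        // matches the host-side formula: sizeof(float) * num_channels * 5 * block.y
        size_t shmem = sizeof(float) * num_channels * NUM_ARRAYS * warps;
        carve<<<1, block, shmem>>>(num_channels);
        cudaDeviceSynchronize();
        return 0;
    }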
@@ -156,10 +156,24 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
     __syncthreads();
 
     const int64_t rbeg = psi_row_offset[ho];
     const int64_t rend = psi_row_offset[ho + 1];
     const int rlen = rend - rbeg;
 
-    // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
+    // First pass: find qdotk_max
+    for (int off = 0; off < rlen; off++) {
+        const int64_t col = psi_col_idx[rbeg + off];
+        const int hi = col / nlon_in;
+        const int wi = col - (hi * nlon_in);
+        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
+
+        float qdotk = 0.0f;
+        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
+            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
+        }
+        qdotk = __warp_sum_cub(qdotk);
+        qdotk_max = max(qdotk_max, qdotk);
+    }
+
+    // Second pass: accumulate alpha_sum, integral, and shared stats
     for (int off = 0; off < rlen; off++) {
         const int64_t col = psi_col_idx[rbeg + off];
         const int hi = col / nlon_in;
@@ -172,26 +186,22 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
         }
         qdotk = __warp_sum_cub(qdotk);
         gdotv = __warp_sum_cub(gdotv);
 
-        float qdotk_max_tmp = max(qdotk_max, qdotk);
-        float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
-        float max_correction = expf(qdotk_max - qdotk_max_tmp);
-        alpha_sum = alpha_sum * max_correction + alpha_inz;
-        integral = integral * max_correction + alpha_inz * gdotv;
+        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
+        alpha_sum += alpha_inz;
+        integral += alpha_inz * gdotv;
 
         for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
             float kxval = kx[batchId][chan][hi][wip];
-            sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval;
-            sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv;
-            sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
+            sh_alpha_k[chan] += alpha_inz * kxval;
+            sh_alpha_vw[chan] += alpha_inz * gdotv;
+            sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
         }
-        qdotk_max = qdotk_max_tmp;
     }
     integral /= alpha_sum;
 
     // Write dydq
     for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
         dydq[batchId][chan][ho][wo] =
             (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
     }
 
     // Third pass: accumulate gradients for k and v
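These two hunks replace the streaming softmax accumulation, which rescaled partial sums by `expf(qdotk_max - qdotk_max_tmp)` whenever the running maximum grew, with a dedicated max pass followed by plain accumulation. A host-side sketch with made-up values (not repo code) showing that the two schemes produce the same stabilized sum:

    // Both loops compute sum_i exp(x_i - max_j x_j); results agree up to rounding.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<float> x = {1.0f, 4.0f, 2.0f, 8.0f, 3.0f};

        // Streaming (online) accumulation with progressive max correction.
        float m = -INFINITY, s_online = 0.0f;
        for (float xi : x) {
            float m_new = std::max(m, xi);
            s_online = s_online * std::exp(m - m_new) + std::exp(xi - m_new);
            m = m_new;
        }

        // Two passes: find the max first, then accumulate plainly.
        float m2 = *std::max_element(x.begin(), x.end());
        float s_twopass = 0.0f;
        for (float xi : x) s_twopass += std::exp(xi - m2);

        printf("online: %f, two-pass: %f\n", s_online, s_twopass);
        return 0;
    }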
@@ -217,11 +227,16 @@
     }
 }
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor>
 s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy, at::Tensor quad_weights,
                            at::Tensor psi_col_idx, at::Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out)
 {
     CHECK_CUDA_TENSOR(kx);
     CHECK_CUDA_TENSOR(vx);
@@ -233,44 +248,28 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     auto stream = at::cuda::getCurrentCUDAStream().stream();
 
-    auto k_channel_first = kx.strides()[1] == 1;
-    auto v_channel_first = vx.strides()[1] == 1;
-    auto q_channel_first = qy.strides()[1] == 1;
-    auto dy_channel_first = dy.strides()[1] == 1;
 
-    // Transpose to [batch, ho, wo, channel]
     nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
     // auto* permute_timer = new ScopeTimer("permute inputs");
-    // Permute kx,vx,qy,dy to [batch, ho, wo, channel] in memory layout, but keep the original shape [batch, channel, ho, wo]
-    auto kxP = at::Tensor();
-    if (!k_channel_first) {
-        // printf("Permuting kx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        kxP = kx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        kxP = kx;
-    }
-    auto vxP = at::Tensor();
-    if (!v_channel_first) {
-        // printf("Permuting vx from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        vxP = vx.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        vxP = vx;
-    }
-    auto qyP = at::Tensor();
-    if (!q_channel_first) {
-        // printf("Permuting qy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        qyP = qy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        qyP = qy;
-    }
-    auto dyP = at::Tensor();
-    if (!dy_channel_first) {
-        // printf("Permuting dy from [batch, channel, ho, wo] to [batch, ho, wo, channel]\n");
-        dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
-    } else {
-        dyP = dy;
-    }
+    // extract dtype
+    auto kx_type = kx.dtype();
+    auto vx_type = vx.dtype();
+    auto qy_type = qy.dtype();
+    auto dy_type = dy.dtype();
+
+    // extract memory format
+    auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
+    auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
+    auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
+    auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
+
+    // convert to channels-last
+    auto kxP = kx.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
+    auto vxP = vx.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
+    auto qyP = qy.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
+    auto dyP = dy.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
 
     // cudaDeviceSynchronize();
     // delete permute_timer;
     nvtxRangePop();
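After this hunk the inputs are channels-last: the channel dimension has stride 1, which is what the kernel's warp-parallel loads over channels need for coalescing. A quick standalone check of what the conversion does to strides while leaving the logical shape alone (hypothetical shape, and `contiguous(at::MemoryFormat::ChannelsLast)` used as a layout-only stand-in for the `.to(...)` call above):

    // Inspect strides before and after the channels-last conversion.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
        auto kx = torch::randn({1, 8, 4, 6});                     // [batch, channel, h, w]
        std::cout << kx.strides() << std::endl;                   // [192, 24, 6, 1]
        auto kxP = kx.contiguous(at::MemoryFormat::ChannelsLast); // layout change only
        std::cout << kxP.strides() << std::endl;                  // [192, 1, 48, 8]
        std::cout << kxP.sizes() << std::endl;                    // shape unchanged: [1, 8, 4, 6]
        return 0;
    }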
@@ -285,8 +284,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     size_t uo_num_channels = kx.size(1);
     const int batch_size = kx.size(0);
 
     dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
     dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
     size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
 
     cudaEvent_t start, stop;
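A worked instance of this launch configuration, assuming a 361 x 720 output grid as in the timing comments further down (the numbers are illustrative): one warp per output point, THREADS/WARP_SIZE warps per block, and enough blocks in grid.x to cover all nlat_out * nlon_out points.

    // Launch-configuration arithmetic for the kernel above.
    #include <cstdio>

    #define WARP_SIZE 32
    #define THREADS 64
    #define DIV_UP(a,b) (((a) + ((b)-1)) / (b))

    int main()
    {
        const int nlat_out = 361, nlon_out = 720, batch_size = 1;
        const int warps_per_block = THREADS / WARP_SIZE;          // 2 output points per block
        const int blocks_x = DIV_UP(nlat_out * nlon_out, warps_per_block);
        printf("block = (%d, %d), grid = (%d, %d)\n",
               WARP_SIZE, warps_per_block, blocks_x, batch_size); // (32, 2), (129960, 1)
        return 0;
    }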
@@ -295,8 +294,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     CHECK_CUDA(cudaEventCreate(&stop));
     CHECK_CUDA(cudaEventRecord(start, stream));
 
     s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
         uo_num_channels, nlon_in, nlat_out, nlon_out,
         kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
         vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
         qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
         dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
@@ -312,10 +313,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
-    // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
-    // s2_attention_bwd_kernel execution time: 50.724865 ms
-    // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
-    // s2_attention_bwd_kernel execution time: 11.679744 ms
-    // printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
+    // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
+    // printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));
@@ -324,11 +323,30 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     // Permute outputs back to memory layout given by input. if input had channels
     // first, leave it in that layout, otherwise permute layout back to [batch,
     // channel, ho, wo]
-    if (!k_channel_first) dydk = dydk.contiguous();
-    if (!v_channel_first) dydv = dydv.contiguous();
-    if (!q_channel_first) dydq = dydq.contiguous();
-    // printf("dydk strides:[");
+    // convert back to original dtype
+    dydk = dydk.to(kx_type);
+    dydv = dydv.to(vx_type);
+    dydq = dydq.to(qy_type);
+
+    // permute back to original layout
+    if (!kx_is_channels_last) {
+        dydk = dydk.to(kx_type, at::MemoryFormat::Contiguous);
+    } else {
+        dydk = dydk.to(kx_type);
+    }
+    if (!vx_is_channels_last) {
+        dydv = dydv.to(vx_type, at::MemoryFormat::Contiguous);
+    } else {
+        dydv = dydv.to(vx_type);
+    }
+    if (!qy_is_channels_last) {
+        dydq = dydq.to(qy_type, at::MemoryFormat::Contiguous);
+    } else {
+        dydq = dydq.to(qy_type);
+    }
     // printf("dydk strides: [");
     // for(auto& stride : dydk.strides()) {
     //     printf("%ld,", stride);
     // }
@@ -337,4 +355,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     // delete permute_output_timer;
     // nvtxRangePop();
 
     return std::make_tuple(dydk, dydv, dydq);
 }