OpenDAS / torch-harmonics · Commits

Commit e1338191, authored Jul 02, 2025 by Thorsten Kurth

    using torch tools to change layout in bd pass

Parent: 49a61eee
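The "torch tools" in the message refers to PyTorch's native memory-format handling, which replaces the hand-rolled permute/contiguous round trips used before this commit. A minimal standalone sketch of the layout equivalence (illustrative only, with made-up tensor sizes; the commit itself uses the combined `to(dtype, memory_format)` call shown in the listing below):

```cuda
#include <torch/torch.h>

int main() {
    auto x = torch::randn({2, 8, 16, 32}); // [batch, channel, h, w], NCHW-contiguous

    // pre-commit approach: repack by hand, keeping the logical NCHW shape
    auto a = x.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});

    // post-commit approach: ask for channels-last strides directly
    auto b = x.contiguous(at::MemoryFormat::ChannelsLast);

    TORCH_CHECK(a.strides() == b.strides()); // both are NHWC in memory, NCHW in shape
    return 0;
}
```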
Changes: 1 changed file with 277 additions and 257 deletions.

torch_harmonics/csrc/attention/attention_bwd_cuda.cu (+277, -257)
torch_harmonics/csrc/attention/attention_bwd_cuda.cu @ e1338191

The first hunk, @@ -2,7 +2,7 @@, falls in the license header; its context reads:

```cuda
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
```
The second hunk, @@ -51,290 +51,310 @@, spans the backward kernel and its host wrapper; it is shown below in its reconstructed post-commit form.

```cuda
#define THREADS (64)
#endif

#ifndef DIV_UP
#define DIV_UP(a,b) (((a) + ((b)-1)) / (b))
#endif

#ifndef CHECK_CUDA
#define CHECK_CUDA(call) { \
    cudaError_t err = call; \
    if (cudaSuccess != err) { \
        fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
        exit(EXIT_FAILURE); \
    }}
#endif

#include <iostream>
#include <chrono>
#include <string>
```
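For context, a minimal sketch (not part of the commit) of how these two helpers are typically used together, assuming a CUDA toolchain:

```cuda
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define DIV_UP(a,b) (((a) + ((b)-1)) / (b))
#define CHECK_CUDA(call) { \
    cudaError_t err = call; \
    if (cudaSuccess != err) { \
        fprintf(stderr, "Cuda error: %s\n", cudaGetErrorString(err)); \
        exit(EXIT_FAILURE); \
    }}

int main() {
    float* buf = nullptr;
    CHECK_CUDA(cudaMalloc(&buf, 1024 * sizeof(float))); // aborts with a message on failure
    int blocks = DIV_UP(1000, 256);                     // 4 blocks of 256 threads cover 1000 items
    (void)blocks;
    CHECK_CUDA(cudaFree(buf));
    return 0;
}
```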
```cuda
class ScopeTimer
{
  public:
    explicit ScopeTimer(const std::string& label = "")
        : label_(label), start_(std::chrono::high_resolution_clock::now()) {}

    ~ScopeTimer()
    {
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
        std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl;
    }

  private:
    std::string label_;
    std::chrono::high_resolution_clock::time_point start_;
};
```
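ScopeTimer is RAII: the destructor prints when the scope closes, and the label is concatenated directly in front of "Elapsed time:", so labels want a trailing space. A hypothetical call site:

```cuda
{
    ScopeTimer timer("permute inputs "); // note the trailing space in the label
    // ... work being timed ...
}   // ~ScopeTimer prints "permute inputs Elapsed time: <n> ms" here
```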
```cuda
static __device__ float __warp_sum(float val)
{
#pragma unroll
    for (int i = WARP_SIZE / 2; i; i /= 2) {
        val += __shfl_xor_sync(FULL_MASK, val, i);
    }
    return val;
}

// easier to understand version of manual shfl_xor_sync, performance appears similar
static __device__ float __warp_sum_cub(float val)
{
    // use cub to reduce within a warp
    __shared__ typename cub::WarpReduce<float>::TempStorage temp_storage;
    // 1. Compute sum (initially only in lane 0)
    float sum = cub::WarpReduce<float>(temp_storage).Sum(val);
    // 2. Broadcast sum to all threads
    sum = __shfl_sync(0xFFFFFFFF, sum, 0);
    return sum;
}
```
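For reference, a self-contained sketch (independent of this file) of the XOR-butterfly reduction that __warp_sum implements, assuming a 32-lane warp:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_sum_demo(float* out) {
    float val = 1.0f; // each of the 32 lanes contributes 1
    // butterfly/XOR pattern: after log2(32) = 5 steps every lane holds the full sum
    for (int i = 16; i; i /= 2) val += __shfl_xor_sync(0xFFFFFFFFu, val, i);
    if (threadIdx.x == 0) *out = val; // 32.0
}

int main() {
    float *d, h;
    cudaMalloc(&d, sizeof(float));
    warp_sum_demo<<<1, 32>>>(d);
    cudaMemcpy(&h, d, sizeof(float), cudaMemcpyDeviceToHost);
    printf("%f\n", h); // 32.000000
    cudaFree(d);
    return 0;
}
```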
```cuda
// This kernel computes the backward pass for the S2 attention mechanism, using
// shared memory as a cache and one warp per output point, warp-parallel over
// channels, which should be laid out in the fastest dimension for coalesced
// memory access.
template<int BDIM_X>
__global__ __launch_bounds__(BDIM_X)
void s2_attention_bwd_dkvq_kernel(int num_channels,
                                  int nlon_in,
                                  int nlat_out,
                                  int nlon_out,
                                  const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> kx,
                                  const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> vx,
                                  const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> qy,
                                  const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dy,
                                  torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydk,
                                  torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydv,
                                  torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dydq,
                                  const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_col_idx,
                                  const torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> psi_row_offset,
                                  const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights)
{
    extern __shared__ float sh[];
    float *sh_alpha_k   = sh + threadIdx.y * num_channels * 5;
    float *sh_alpha_vw  = sh_alpha_k + num_channels;
    float *sh_alpha_kvw = sh_alpha_vw + num_channels;
    float *sh_dy        = sh_alpha_kvw + num_channels;
    float *sh_qy        = sh_dy + num_channels;
    // (optionally, could use more shared memory for other intermediates)

    const uint64_t batchId = blockIdx.y;
    const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y;

    if (wid >= uint64_t(nlat_out) * nlon_in) return;

    const int tidx = threadIdx.x;
    const int ho = wid / nlon_out;
    const int wo = wid - (ho * nlon_out);
```
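For orientation, the quantities the kernel accumulates are consistent with a quadrature-weighted softmax attention over the sparse neighborhood $\mathcal{N}(ho)$ encoded in CSR form by psi_row_offset/psi_col_idx (this formula is reconstructed from the code below, not stated in the commit):

$$y_{c,ho,wo} \;=\; \frac{\sum_{j \in \mathcal{N}(ho)} \exp\!\big(q_{\cdot,ho,wo} \cdot k_{\cdot,j}\big)\,\omega_{h_j}\, v_{c,j}}{\sum_{j \in \mathcal{N}(ho)} \exp\!\big(q_{\cdot,ho,wo} \cdot k_{\cdot,j}\big)\,\omega_{h_j}}$$

where $\omega_{h_j}$ is the quadrature weight of input latitude $h_j$; in the code, one un-normalized term is alpha_inz and the denominator is alpha_sum.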
```cuda
    // Zero shared memory
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        sh_alpha_k[chan] = 0.0f;
        sh_alpha_vw[chan] = 0.0f;
        sh_alpha_kvw[chan] = 0.0f;
        sh_dy[chan] = dy[batchId][chan][ho][wo];
        sh_qy[chan] = qy[batchId][chan][ho][wo];
    }

    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;
    float integral = 0.0f;
    __syncthreads();

    const int64_t rbeg = psi_row_offset[ho];
    const int64_t rend = psi_row_offset[ho + 1];
    const int rlen = rend - rbeg;

    // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a
    // progressively computed qdotk_max.
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f, gdotv = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        gdotv = __warp_sum_cub(gdotv);

        float qdotk_max_tmp = max(qdotk_max, qdotk);
        float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        float max_correction = expf(qdotk_max - qdotk_max_tmp);
        alpha_sum = alpha_sum * max_correction + alpha_inz;
        integral = integral * max_correction + alpha_inz * gdotv;

        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            float kxval = kx[batchId][chan][hi][wip];
            sh_alpha_k[chan]   = sh_alpha_k[chan]   * max_correction + alpha_inz * kxval;
            sh_alpha_vw[chan]  = sh_alpha_vw[chan]  * max_correction + alpha_inz * gdotv;
            sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv;
        }
        qdotk_max = qdotk_max_tmp;
    }
```
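The max_correction factor is the standard online-softmax rescaling. With running maximum $m$, a new score $s = q \cdot k$, and $m' = \max(m, s)$, every previously accumulated term is rebased from $m$ to $m'$:

$$e^{s_j - m}\,e^{m - m'} = e^{s_j - m'} \quad\Longrightarrow\quad \texttt{alpha\_sum} \leftarrow \texttt{alpha\_sum}\cdot e^{m - m'} + e^{s - m'}\,\omega_{h}$$

This lets the pass that previously only searched for the maximum be folded into the accumulation pass.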
```cuda
    integral /= alpha_sum;

    // Write dydq
    for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
        dydq[batchId][chan][ho][wo] =
            (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
    }

    // Third pass: accumulate gradients for k and v
    for (int off = 0; off < rlen; off++) {
        const int64_t col = psi_col_idx[rbeg + off];
        const int hi = col / nlon_in;
        const int wi = col - (hi * nlon_in);
        const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;

        float qdotk = 0.0f, gdotv = 0.0f;
        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
            gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
        }
        qdotk = __warp_sum_cub(qdotk);
        gdotv = __warp_sum_cub(gdotv);
        float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];

        for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
            float qyval = qy[batchId][chan][ho][wo];
            float dyval = sh_dy[chan];
            atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
            atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
        }
    }
}
```
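The dydq expression is the quotient rule applied to the normalized weights. Writing $\alpha_j$ for the un-normalized weight, $A = \sum_j \alpha_j$ (alpha_sum), $g$ for the incoming gradient dy, and matching the shared accumulators (sh_alpha_kvw $= \sum_j \alpha_j (g\cdot v_j)\,k_{j,c}$, sh_alpha_vw $= \sum_j \alpha_j (g\cdot v_j)$, sh_alpha_k $= \sum_j \alpha_j k_{j,c}$), the gradients the loops produce are, as reconstructed from the code:

$$\frac{\partial L}{\partial q_c} = \frac{\big(\sum_j \alpha_j (g\cdot v_j)\,k_{j,c}\big)\,A \;-\; \big(\sum_j \alpha_j (g\cdot v_j)\big)\big(\sum_j \alpha_j k_{j,c}\big)}{A^2},\qquad
\frac{\partial L}{\partial k_{j,c}} = q_c\,\frac{\alpha_j}{A}\Big(g\cdot v_j - \tfrac{1}{A}\textstyle\sum_l \alpha_l\, g\cdot v_l\Big),\qquad
\frac{\partial L}{\partial v_{j,c}} = \frac{\alpha_j}{A}\, g_c$$

where $\tfrac{1}{A}\sum_l \alpha_l\, g\cdot v_l$ is `integral` after the division by alpha_sum.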
```cuda
std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy,
                                                                          at::Tensor dy, at::Tensor quad_weights,
                                                                          at::Tensor psi_col_idx, at::Tensor psi_row_off,
                                                                          int nlon_in, int nlat_out, int nlon_out)
{
    CHECK_CUDA_TENSOR(kx);
    CHECK_CUDA_TENSOR(vx);
    CHECK_CUDA_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);
    CHECK_CUDA_TENSOR(dy);

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // Transpose to [batch, ho, wo, channel]
    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
    // auto* permute_timer = new ScopeTimer("permute inputs");

    // extract dtype
    auto kx_type = kx.dtype();
    auto vx_type = vx.dtype();
    auto qy_type = qy.dtype();
    auto dy_type = dy.dtype();

    // extract memory format
    auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
    auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
    auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
    auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);

    // convert to channels-last
    auto kxP = kx.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
    auto vxP = vx.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
    auto qyP = qy.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);
    auto dyP = dy.to(torch::kFloat32, at::MemoryFormat::ChannelsLast);

    // cudaDeviceSynchronize();
    // delete permute_timer;
    nvtxRangePop();

    nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
    auto dydk = torch::zeros_like(qyP);
    auto dydv = torch::zeros_like(qyP);
    auto dydq = torch::zeros_like(qyP);
    // print stride of dydkP, dydvP, dydqP
    nvtxRangePop();
```
```cuda
    size_t uo_num_channels = kx.size(1);
    const int batch_size = kx.size(0);

    dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
    size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 arrays per warp

    cudaEvent_t start, stop;
    float milliseconds = 0;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));
    CHECK_CUDA(cudaEventRecord(start, stream));
```
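With the definitions above (THREADS = 64, WARP_SIZE = 32), each block carries two warps and each warp handles one output point. A standalone sketch of the arithmetic with hypothetical problem sizes (256 channels, a 721 x 1440 output grid):

```cuda
#include <cstdio>
#define WARP_SIZE 32
#define THREADS 64
#define DIV_UP(a,b) (((a) + ((b)-1)) / (b))

int main() {
    const int num_channels = 256, nlat_out = 721, nlon_out = 1440;  // illustrative only
    const int warps_per_block = THREADS / WARP_SIZE;                // 2 warps, one output point each
    const long long blocks = DIV_UP((long long)nlat_out * nlon_out, warps_per_block);
    const size_t shared = sizeof(float) * num_channels * 5 * warps_per_block; // 5 arrays per warp
    printf("blocks=%lld shared=%zu bytes\n", blocks, shared);       // blocks=519120 shared=10240
    return 0;
}
```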
```cuda
    s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
        uo_num_channels, nlon_in, nlat_out, nlon_out,
        kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
        quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());

    CHECK_CUDA(cudaEventRecord(stop, stream));
    CHECK_CUDA(cudaEventSynchronize(stop));
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
```
```cuda
    // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
    // s2_attention_bwd_kernel execution time: 50.724865 ms
    // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
    // s2_attention_bwd_kernel execution time: 11.679744 ms
    // printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);

    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Permute outputs back to memory layout given by input. if input had channels
    // first, leave it in that layout, otherwise permute layout back to [batch,
    // channel, ho, wo]
    // convert back to original dtype and permute back to original layout
    if (!kx_is_channels_last) {
        dydk = dydk.to(kx_type, at::MemoryFormat::Contiguous);
    } else {
        dydk = dydk.to(kx_type);
    }
    if (!vx_is_channels_last) {
        dydv = dydv.to(vx_type, at::MemoryFormat::Contiguous);
    } else {
        dydv = dydv.to(vx_type);
    }
    if (!qy_is_channels_last) {
        dydq = dydq.to(qy_type, at::MemoryFormat::Contiguous);
    } else {
        dydq = dydq.to(qy_type);
    }

    // cudaDeviceSynchronize();
    // delete permute_output_timer;
    // nvtxRangePop();

    return std::make_tuple(dydk, dydv, dydq);
}
```
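For reference, a minimal standalone sketch (not from the commit, with made-up sizes) of the dtype/memory-format round trip the wrapper performs, written in the equivalent two-step `contiguous()` form:

```cuda
#include <torch/torch.h>

int main() {
    // NCHW float64 input, as a caller might pass it
    auto x = torch::randn({2, 8, 16, 32}, torch::kFloat64);
    bool was_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); // false here

    // cast for the kernel and repack to channels-last strides; shape stays [N, C, H, W]
    auto xP = x.to(torch::kFloat32).contiguous(at::MemoryFormat::ChannelsLast);

    // restore the caller's dtype and, since the input was not channels-last, a packed NCHW layout
    auto y = xP.to(x.scalar_type());
    if (!was_channels_last) { y = y.contiguous(); }

    TORCH_CHECK(y.sizes() == x.sizes() && y.strides() == x.strides());
    return 0;
}
```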