OpenDAS / torch-harmonics / Commits

Commit 6ac50e26, authored Jul 16, 2025 by Thorsten Kurth
parent 45fc2a46

removing commented code
Showing 2 changed files with 0 additions and 139 deletions:

  torch_harmonics/csrc/attention/attention_bwd_cuda.cu  +0 -89
  torch_harmonics/csrc/attention/attention_utils.cu     +0 -50
torch_harmonics/csrc/attention/attention_bwd_cuda.cu

...
@@ -899,95 +899,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     CHECK_CUDA_TENSOR(quad_weights);
     CHECK_CUDA_TENSOR(psi_col_idx);
     CHECK_CUDA_TENSOR(psi_row_off);
-    // #if 0
-    // // extract dtype
-    // auto kx_type = kx.dtype();
-    // auto vx_type = vx.dtype();
-    // auto qy_type = qy.dtype();
-    // auto dy_type = dy.dtype();
-    // // exract memory format
-    // auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
-    // auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
-    // auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
-    // auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
-    // // convert to channels-last
-    // auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-    // auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-    // auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-    // auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
-    // // create output arrays
-    // auto dydk = torch::zeros_like(qyP);
-    // auto dydv = torch::zeros_like(qyP);
-    // auto dydq = torch::zeros_like(qyP);
-    // size_t uo_num_channels = kx.size(1);
-    // const int batch_size = kx.size(0);
-    // dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
-    // dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
-    // size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
-    // cudaEvent_t start, stop;
-    // float milliseconds = 0;
-    // CHECK_CUDA(cudaEventCreate(&start));
-    // CHECK_CUDA(cudaEventCreate(&stop));
-    // CHECK_CUDA(cudaEventRecord(start, stream));
-    // s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
-    //     uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-    //     vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-    //     qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-    //     dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-    //     dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-    //     dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-    //     dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-    //     psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-    //     psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
-    //     quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
-    // CHECK_CUDA(cudaEventRecord(stop, stream));
-    // CHECK_CUDA(cudaEventSynchronize(stop));
-    // CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
-    // // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
-    // // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
-    // CHECK_CUDA(cudaEventDestroy(start));
-    // CHECK_CUDA(cudaEventDestroy(stop));
-    // C10_CUDA_KERNEL_LAUNCH_CHECK();
-    // // Permute outputs back to memory layout given by input. if input had channels
-    // // first, leave it in that layout, otherwise permute layout back to [batch,
-    // // channel, ho, wo]
-    // // convert back to original dtype
-    // dydk = dydk.to(kx_type);
-    // dydv = dydv.to(vx_type);
-    // dydq = dydq.to(qy_type);
-    // // permute back to original layout
-    // if (!kx_is_channels_last) {
-    //     dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
-    // } else {
-    //     dydk = dydk.to(kx_type);
-    // }
-    // if (!vx_is_channels_last) {
-    //     dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
-    // } else {
-    //     dydv = dydv.to(vx_type);
-    // }
-    // if (!qy_is_channels_last) {
-    //     dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
-    // } else {
-    //     dydq = dydq.to(qy_type);
-    // }
-    // return std::make_tuple(dydk, dydv, dydq);
-    // #else
     const size_t uo_num_channels = kx.size(1);
     const int batch_size = kx.size(0);
...
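Note: the block removed above was a dead `#if 0` path that cast the inputs to float32 channels-last, launched `s2_attention_bwd_dkvq_kernel`, and measured it with CUDA events. For reference, here is a minimal, self-contained sketch of that event-timing idiom only; `my_kernel` is a hypothetical stand-in, and none of this is torch-harmonics code:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel standing in for the real one being timed.
__global__ void my_kernel() {}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Same Create/Record/Synchronize/ElapsedTime sequence as the deleted block.
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start, stream);
    my_kernel<<<1, 32, 0, stream>>>();
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);

    float milliseconds = 0.0f;
    cudaEventElapsedTime(&milliseconds, start, stop);
    printf("kernel time: %f ms\n", milliseconds);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    cudaStreamDestroy(stream);
    return 0;
}

Event timing measures elapsed GPU time on the stream rather than host wall-clock time, which is why the deleted benchmark scaffolding used it.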
torch_harmonics/csrc/attention/attention_utils.cu

...
@@ -121,17 +121,6 @@ static int getPtxver() {
 at::Tensor permute_4D_to0231(at::Tensor src) {
-    //dim3 block;
-    //dim3 grid;
-    //block.x = WARP_SIZE;
-    //grid.x = DIV_UP(src.size(1), block.x);
-    //grid.y = DIV_UP(src.size(3), block.x);
-    //grid.z = src.size(2)*src.size(0);
-    //assert(grid.y < 65536);
-    //assert(grid.z < 65536);
     auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
     torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);
...
@@ -142,25 +131,11 @@ at::Tensor permute_4D_to0231(at::Tensor src) {
         AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] {
             launch_permute_to0231<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
         }));
-        //block.y = TRANSP_WARPS_X_TILE_GENERIC;
-        //permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
-        //    <<<grid, block, 0, stream>>>(src.size(1),
-        //                                 src.size(2),
-        //                                 src.size(3),
-        //                                 src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        //                                 dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
         CHECK_ERROR("permute_to0231_k_tile_generic");
     } else {
         AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] {
             launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
         }));
-        //block.y = TRANSP_WARPS_X_TILE_SM100;
-        //permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
-        //    <<<grid, block, 0, stream>>>(src.size(1),
-        //                                 src.size(2),
-        //                                 src.size(3),
-        //                                 src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        //                                 dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
         CHECK_ERROR("permute_to0231_k_tile_sm100");
     }
 }
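Note: the launch-configuration comments removed in this hunk duplicated what `launch_permute_to0231` now encapsulates. Semantically, `permute_4D_to0231` turns a [N, C, H, W] tensor into a materialized [N, H, W, C] copy. A minimal ATen sketch of the same result, for reference only; the function name below is illustrative, not part of the library:

#include <torch/torch.h>

// Reference semantics for the custom kernel path above:
// [N, C, H, W] -> [N, H, W, C], materialized in the new layout.
at::Tensor permute_0231_reference(const at::Tensor& src) {
    TORCH_CHECK(src.dim() == 4, "expected a 4D tensor");
    return src.permute({0, 2, 3, 1}).contiguous();
}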
...
@@ -169,17 +144,6 @@ at::Tensor permute_4D_to0231(at::Tensor src) {
 at::Tensor permute_4D_to0312(at::Tensor src) {
-    //dim3 block;
-    //dim3 grid;
-    //block.x = WARP_SIZE;
-    //grid.x = DIV_UP(src.size(2), block.x);
-    //grid.y = DIV_UP(src.size(3), block.x);
-    //grid.z = src.size(1)*src.size(0);
-    //assert(grid.y < 65536);
-    //assert(grid.z < 65536);
     auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
     torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);
...
@@ -187,28 +151,14 @@ at::Tensor permute_4D_to0312(at::Tensor src) {
     // to be further specialized for additional archs, if necessary
     if (ptxv < 100) {
-        //block.y = TRANSP_WARPS_X_TILE_GENERIC;
         AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] {
             launch_permute_to0312<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
         }));
-        //permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
-        //    <<<grid, block, 0, stream>>>(src.size(3),
-        //                                 src.size(1),
-        //                                 src.size(2),
-        //                                 src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        //                                 dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
         CHECK_ERROR("permute_to0312_k_tile_generic");
     } else {
         AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] {
             launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
         }));
-        //block.y = TRANSP_WARPS_X_TILE_SM100;
-        //permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
-        //    <<<grid, block, 0, stream>>>(src.size(3),
-        //                                 src.size(1),
-        //                                 src.size(2),
-        //                                 src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
-        //                                 dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
         CHECK_ERROR("permute_to0312_k_tile_sm100");
     }
 }
...
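Note: as in the 0231 case, the removed comments were a stale copy of the launch now wrapped by `launch_permute_to0312`; the surviving code selects `TRANSP_WARPS_X_TILE_GENERIC` when `getPtxver()` reports a value below 100 and `TRANSP_WARPS_X_TILE_SM100` otherwise, then dispatches over floating dtypes. The permutation itself is the inverse of the 0231 transform; a minimal ATen sketch for reference (the function name is illustrative):

#include <torch/torch.h>

// Inverse of the 0231 reference: [N, H, W, C] -> [N, C, H, W].
// Composing the two round-trips a tensor to its original layout.
at::Tensor permute_0312_reference(const at::Tensor& src) {
    TORCH_CHECK(src.dim() == 4, "expected a 4D tensor");
    return src.permute({0, 3, 1, 2}).contiguous();
}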