Commit 6ac50e26 authored by Thorsten Kurth

removing commented code

parent 45fc2a46
@@ -899,95 +899,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
CHECK_CUDA_TENSOR(quad_weights);
CHECK_CUDA_TENSOR(psi_col_idx);
CHECK_CUDA_TENSOR(psi_row_off);
// #if 0
// // extract dtype
// auto kx_type = kx.dtype();
// auto vx_type = vx.dtype();
// auto qy_type = qy.dtype();
// auto dy_type = dy.dtype();
// // exract memory format
// auto kx_is_channels_last = kx.is_contiguous(at::MemoryFormat::ChannelsLast);
// auto vx_is_channels_last = vx.is_contiguous(at::MemoryFormat::ChannelsLast);
// auto qy_is_channels_last = qy.is_contiguous(at::MemoryFormat::ChannelsLast);
// auto dy_is_channels_last = dy.is_contiguous(at::MemoryFormat::ChannelsLast);
// // convert to channels-last
// auto kxP = kx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// auto vxP = vx.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// auto qyP = qy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// auto dyP = dy.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
// // create output arrays
// auto dydk = torch::zeros_like(qyP);
// auto dydv = torch::zeros_like(qyP);
// auto dydq = torch::zeros_like(qyP);
// size_t uo_num_channels = kx.size(1);
// const int batch_size = kx.size(0);
// dim3 block(WARP_SIZE, THREADS / WARP_SIZE);
// dim3 grid(DIV_UP(nlat_out * nlon_out, block.y), batch_size);
// size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 4 arrays per warp
// cudaEvent_t start, stop;
// float milliseconds = 0;
// CHECK_CUDA(cudaEventCreate(&start));
// CHECK_CUDA(cudaEventCreate(&stop));
// CHECK_CUDA(cudaEventRecord(start, stream));
// s2_attention_bwd_dkvq_kernel<THREADS><<<grid, block, shared_size, stream>>>(
// uo_num_channels, nlon_in, nlat_out, nlon_out, kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dydk.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dydv.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dydq.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// psi_col_idx.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
// psi_row_off.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
// quad_weights.packed_accessor32<float, 1, torch::RestrictPtrTraits>());
// CHECK_CUDA(cudaEventRecord(stop, stream));
// CHECK_CUDA(cudaEventSynchronize(stop));
// CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// CHECK_CUDA(cudaEventDestroy(start));
// CHECK_CUDA(cudaEventDestroy(stop));
// C10_CUDA_KERNEL_LAUNCH_CHECK();
// // Permute outputs back to memory layout given by input. if input had channels
// // first, leave it in that layout, otherwise permute layout back to [batch,
// // channel, ho, wo]
// // convert back to original dtype
// dydk = dydk.to(kx_type);
// dydv = dydv.to(vx_type);
// dydq = dydq.to(qy_type);
// // permute back to original layout
// if (!kx_is_channels_last) {
// dydk = dydk.to(kx_type).to(at::MemoryFormat::Contiguous);
// } else {
// dydk = dydk.to(kx_type);
// }
// if (!vx_is_channels_last) {
// dydv = dydv.to(vx_type).to(at::MemoryFormat::Contiguous);
// } else {
// dydv = dydv.to(vx_type);
// }
// if (!qy_is_channels_last) {
// dydq = dydq.to(qy_type).to(at::MemoryFormat::Contiguous);
// } else {
// dydq = dydq.to(qy_type);
// }
// return std::make_tuple(dydk, dydv, dydq);
// #else
const size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
...
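
Aside from the kernel launch itself, the commented-out backward path removed above follows a standard ATen round trip: record the incoming dtype and memory format, compute in float32 channels-last, then cast and permute the outputs back. A minimal sketch of that pattern, assuming a hypothetical 4D NCHW tensor `x` (the helper name is illustrative, not part of the source):

// Illustrative sketch only (not part of this commit): the dtype / memory-format
// round trip used by the removed block above, for a hypothetical 4D tensor `x`.
#include <torch/extension.h>

at::Tensor channels_last_roundtrip(at::Tensor x) {
    // remember the caller's dtype and layout
    auto x_type = x.dtype();
    bool x_is_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);

    // compute in float32, channels-last
    auto xP = x.to(torch::kFloat32).to(at::MemoryFormat::ChannelsLast);
    auto out = torch::zeros_like(xP);
    // ... a kernel launch writing into `out` would go here ...

    // cast back to the original dtype, and back to NCHW-contiguous if needed
    out = out.to(x_type);
    if (!x_is_channels_last) { out = out.to(at::MemoryFormat::Contiguous); }
    return out;
}
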
@@ -121,17 +121,6 @@ static int getPtxver() {
at::Tensor permute_4D_to0231(at::Tensor src) {
//dim3 block;
//dim3 grid;
//block.x = WARP_SIZE;
//grid.x = DIV_UP(src.size(1), block.x);
//grid.y = DIV_UP(src.size(3), block.x);
//grid.z = src.size(2)*src.size(0);
//assert(grid.y < 65536);
//assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);
@@ -142,25 +131,11 @@ at::Tensor permute_4D_to0231(at::Tensor src) {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] {
launch_permute_to0231<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
}));
//block.y = TRANSP_WARPS_X_TILE_GENERIC;
//permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
// <<<grid, block, 0, stream>>>(src.size(1),
// src.size(2),
// src.size(3),
// src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
CHECK_ERROR("permute_to0231_k_tile_generic"); CHECK_ERROR("permute_to0231_k_tile_generic");
} else { } else {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] { AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] {
launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst); launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
})); }));
//block.y = TRANSP_WARPS_X_TILE_SM100;
//permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
// <<<grid, block, 0, stream>>>(src.size(1),
// src.size(2),
// src.size(3),
// src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
CHECK_ERROR("permute_to0231_k_tile_sm100"); CHECK_ERROR("permute_to0231_k_tile_sm100");
} }
...@@ -169,17 +144,6 @@ at::Tensor permute_4D_to0231(at::Tensor src) { ...@@ -169,17 +144,6 @@ at::Tensor permute_4D_to0231(at::Tensor src) {
at::Tensor permute_4D_to0312(at::Tensor src) { at::Tensor permute_4D_to0312(at::Tensor src) {
//dim3 block;
//dim3 grid;
//block.x = WARP_SIZE;
//grid.x = DIV_UP(src.size(2), block.x);
//grid.y = DIV_UP(src.size(3), block.x);
//grid.z = src.size(1)*src.size(0);
//assert(grid.y < 65536);
//assert(grid.z < 65536);
auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);
@@ -187,28 +151,14 @@ at::Tensor permute_4D_to0312(at::Tensor src) {
// to be further specialized for additional archs, if necessary
if (ptxv < 100) {
//block.y = TRANSP_WARPS_X_TILE_GENERIC;
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] {
launch_permute_to0312<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
}));
//permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_GENERIC>
// <<<grid, block, 0, stream>>>(src.size(3),
// src.size(1),
// src.size(2),
// src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
CHECK_ERROR("permute_to0312_k_tile_generic"); CHECK_ERROR("permute_to0312_k_tile_generic");
} else { } else {
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] { AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] {
launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst); launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
})); }));
//block.y = TRANSP_WARPS_X_TILE_SM100;
//permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
// <<<grid, block, 0, stream>>>(src.size(3),
// src.size(1),
// src.size(2),
// src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
// dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
CHECK_ERROR("permute_to0312_k_tile_sm100"); CHECK_ERROR("permute_to0312_k_tile_sm100");
} }
......
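
For context, the two helpers touched in this hunk convert between NCHW and NHWC-style orderings: judging from the `torch::empty` shapes above, `permute_4D_to0231` maps dims to {0,2,3,1} and `permute_4D_to0312` maps them to {0,3,1,2}, so one undoes the other. A minimal usage sketch, assuming the declarations are visible from a header; tensor name and sizes are hypothetical:

// Illustrative sketch only: round-tripping a tensor through the two permute
// helpers kept by this commit. Assumes their declarations are available here.
#include <torch/extension.h>

at::Tensor permute_4D_to0231(at::Tensor src);
at::Tensor permute_4D_to0312(at::Tensor src);

void permute_roundtrip_example() {
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    at::Tensor x = torch::randn({2, 8, 16, 32}, options);  // hypothetical [B, C, H, W]

    at::Tensor y = permute_4D_to0231(x);  // -> [B, H, W, C]
    at::Tensor z = permute_4D_to0312(y);  // -> [B, C, H, W] again

    TORCH_CHECK(z.sizes() == x.sizes(), "round trip should restore the NCHW shape");
}
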