Commit fa58767d authored by Mauro Bisson
Browse files

Renamed the template parameter to a simpler name (it's the number of warps per...

Renamed the template parameter to a simpler name (it's the number of warps per tile used in the permutation).
parent 763d4371
...@@ -262,13 +262,13 @@ void permute_to0231_k(const int nchn, ...@@ -262,13 +262,13 @@ void permute_to0231_k(const int nchn,
return; return;
} }
template<int TRANSP_WARPS_X_TILE_SIZE, typename VAL_T> template<int WARPS_X_TILE, typename VAL_T>
void launch_permute_to0231(at::Tensor src, at::Tensor dst){ void launch_permute_to0231(at::Tensor src, at::Tensor dst){
dim3 block; dim3 block;
dim3 grid; dim3 grid;
block.x = WARP_SIZE; block.x = WARP_SIZE;
block.y = TRANSP_WARPS_X_TILE_SIZE; block.y = WARPS_X_TILE;
grid.x = DIV_UP(src.size(1), block.x); grid.x = DIV_UP(src.size(1), block.x);
grid.y = DIV_UP(src.size(3), block.x); grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(2)*src.size(0); grid.z = src.size(2)*src.size(0);
...@@ -279,7 +279,7 @@ void launch_permute_to0231(at::Tensor src, at::Tensor dst){ ...@@ -279,7 +279,7 @@ void launch_permute_to0231(at::Tensor src, at::Tensor dst){
// get stream // get stream
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SIZE> permute_to0231_k<WARP_SIZE, WARPS_X_TILE>
<<<grid, block, 0, stream>>>(src.size(1), <<<grid, block, 0, stream>>>(src.size(1),
src.size(2), src.size(2),
src.size(3), src.size(3),
...@@ -347,13 +347,13 @@ void permute_to0312_k(const int nchn, ...@@ -347,13 +347,13 @@ void permute_to0312_k(const int nchn,
return; return;
} }
template<int TRANSP_WARPS_X_TILE_SIZE, typename VAL_T> template<int WARPS_X_TILE, typename VAL_T>
void launch_permute_to0312(at::Tensor src, at::Tensor dst){ void launch_permute_to0312(at::Tensor src, at::Tensor dst){
dim3 block; dim3 block;
dim3 grid; dim3 grid;
block.x = WARP_SIZE; block.x = WARP_SIZE;
block.y = TRANSP_WARPS_X_TILE_SIZE; block.y = WARPS_X_TILE;
grid.x = DIV_UP(src.size(2), block.x); grid.x = DIV_UP(src.size(2), block.x);
grid.y = DIV_UP(src.size(3), block.x); grid.y = DIV_UP(src.size(3), block.x);
grid.z = src.size(1)*src.size(0); grid.z = src.size(1)*src.size(0);
...@@ -364,7 +364,7 @@ void launch_permute_to0312(at::Tensor src, at::Tensor dst){ ...@@ -364,7 +364,7 @@ void launch_permute_to0312(at::Tensor src, at::Tensor dst){
// get stream // get stream
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SIZE> permute_to0312_k<WARP_SIZE, WARPS_X_TILE>
<<<grid, block, 0, stream>>>(src.size(3), <<<grid, block, 0, stream>>>(src.size(3),
src.size(1), src.size(1),
src.size(2), src.size(2),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment