"vscode:/vscode.git/clone" did not exist on "f61f340279661582f96f315c2e8d0a798587128d"
threadblock_swizzle.h 1.9 KB
Newer Older
1
2
3
4
5
6
7
8
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include "common.h"

namespace tl {

9
template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
10
11
12
13
14
15
16
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
  const unsigned int panel_size = panel_width * gridDim.x;
  const unsigned int panel_offset = block_idx % panel_size;
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
  const unsigned int stride =
17
18
19
20
21
22
      panel_idx + 1 < total_panel
          ? panel_width
          : (grid_size - panel_idx * panel_size) / gridDim.x;
  const unsigned int col_idx = (panel_idx & 1)
                                   ? gridDim.x - 1 - panel_offset / stride
                                   : panel_offset / stride;
23
24
25
26
  const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
  return {col_idx, row_idx, blockIdx.z};
}

27
template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
28
29
30
31
32
33
34
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
  const unsigned int panel_size = panel_width * gridDim.y;
  const unsigned int panel_offset = block_idx % panel_size;
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
  const unsigned int stride =
35
36
37
38
39
40
      panel_idx + 1 < total_panel
          ? panel_width
          : (grid_size - panel_idx * panel_size) / gridDim.y;
  const unsigned int row_idx = (panel_idx & 1)
                                   ? gridDim.y - 1 - panel_offset / stride
                                   : panel_offset / stride;
41
42
43
44
  const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
  return {col_idx, row_idx, blockIdx.z};
}

45
} // namespace tl