multi_tensor_apply.cuh 5.07 KB
Newer Older
1
2
3
4
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
5
#include <c10/cuda/CUDAGuard.h>
6
#include <THC/THC.h>
7
#include "compat.h"
8
9
10
11
12
13
14

#include <assert.h>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

15
16

// TODO: The kernel argument size limit may be smaller than 4 KB on some other devices (e.g., Jetson).
17
18
19
// Capacity tables for TensorListMetadata, indexed by (depth - 1).
//   depth_to_max_tensors: the most tensors one launch can describe.
//   depth_to_max_blocks:  the most blocks (chunks) one launch can describe.
// TensorListMetadata is passed to the kernel by value, so these tables bound its size.
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

// Work description for one launch of multi_tensor_apply_kernel over `n` parallel tensor lists.
//   addresses[d][i]          data pointer of the tensor in slot i, taken from list d
//   sizes[i]                 element count of the tensor in slot i
//   block_to_tensor[b]       slot index that block b works on
//   block_to_chunk[b]        which chunk of that tensor block b works on
//   start_tensor_this_launch index (in the caller's tensor lists) of the tensor in slot 0
template<int n> struct TensorListMetadata
{
  static_assert(n >= 1 &&
                n <= static_cast<int>(sizeof(depth_to_max_tensors) / sizeof(depth_to_max_tensors[0])),
                "TensorListMetadata depth must be between 1 and the length of depth_to_max_tensors");
  // block_to_tensor stores slot indices as unsigned char, so the slot count must fit in it.
  static_assert(depth_to_max_tensors[n-1] <= 256,
                "block_to_tensor (unsigned char) cannot index this many tensors");

  void* addresses[n][depth_to_max_tensors[n-1]];
  int sizes[depth_to_max_tensors[n-1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
  int start_tensor_this_launch;
};

// Kernel parameters are limited to 4 KB on many devices/toolkits, and the metadata is
// passed by value alongside the other kernel arguments. Fail the build if a change to
// the capacity tables pushes any depth past that limit.
static_assert(sizeof(TensorListMetadata<1>) <= 4096, "TensorListMetadata<1> exceeds the 4 KB kernel argument limit");
static_assert(sizeof(TensorListMetadata<2>) <= 4096, "TensorListMetadata<2> exceeds the 4 KB kernel argument limit");
static_assert(sizeof(TensorListMetadata<3>) <= 4096, "TensorListMetadata<3> exceeds the 4 KB kernel argument limit");
static_assert(sizeof(TensorListMetadata<4>) <= 4096, "TensorListMetadata<4> exceeds the 4 KB kernel argument limit");
static_assert(sizeof(TensorListMetadata<5>) <= 4096, "TensorListMetadata<5> exceeds the 4 KB kernel argument limit");


// Generic entry point for multi-tensor work. The kernel does no indexing of its own.
// Every block forwards the same launch-wide arguments to `callable`, which decides
// what to do with them. multi_tensor_apply launches this with one block per entry in
// tl.block_to_tensor / tl.block_to_chunk.
// NOTE(review): presumably the functor uses blockIdx.x to look up its chunk in `tl`; confirm per functor.
template<typename TensorListT, typename Functor, typename... ArgTypes>
#ifdef __HIP_PLATFORM_HCC__
__launch_bounds__(1024)
#endif
__global__ void multi_tensor_apply_kernel(
    int chunk_size,
    volatile int* noop_flag,
    TensorListT tl,
    Functor callable,
    ArgTypes... args)
{
  // All per-block chunk selection and processing is delegated to the functor.
  callable(chunk_size, noop_flag, tl, args...);
}

// Launches multi_tensor_apply_kernel on `stream` with `num_blocks` blocks, one per entry
// already recorded in tl.block_to_tensor / tl.block_to_chunk, then checks for launch
// errors. A zero-block grid is an invalid launch configuration, so nothing is launched
// when num_blocks == 0.
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply_launch_blocks(
  int num_blocks,
  int block_size,
  cudaStream_t stream,
  int chunk_size,
  const at::Tensor& noop_flag,
  const TensorListMetadata<depth>& tl,
  T callable,
  ArgTypes... args)
{
  if (num_blocks == 0)
    return;
  // tl is copied into the kernel's argument buffer at launch, so the caller may
  // overwrite it as soon as this returns.
  multi_tensor_apply_kernel<<<num_blocks, block_size, 0, stream>>>(
    chunk_size,
    noop_flag.DATA_PTR<int>(),
    tl,
    callable,
    args...);
  AT_CUDA_CHECK(cudaGetLastError());
}

// Applies `callable` to `depth` parallel lists of CUDA tensors using as few kernel
// launches as the TensorListMetadata capacity allows.
//
// Each tensor is split into ceil(numel / chunk_size) chunks, and each chunk becomes one
// block. Tensors are packed into metadata slots. A launch is issued when either:
//   - the slots are full (only after a tensor's last chunk), or
//   - the block table is full.
// When a launch splits a tensor, that tensor is carried into slot 0 of the next launch.
// tl.start_tensor_this_launch always holds the list index of the tensor in slot 0.
//
// Parameters:
//   block_size   - threads per block for every launch.
//   chunk_size   - elements per chunk; must be > 0.
//   noop_flag    - tensor whose int data pointer is passed to every launch.
//   tensor_lists - exactly `depth` lists of equal, non-zero length. Every tensor must:
//                  * be on the same CUDA device as tensor_lists[0][0];
//                  * be contiguous (channels-last is also accepted with VERSION_GE_1_5);
//                  * have the same numel as its counterpart in list 0.
//                  Sparse tensors are densified into temporaries.
//   callable     - functor the kernel invokes as callable(chunk_size, noop_flag, tl, args...).
//   args         - extra arguments forwarded unchanged to every launch.
// Throws (via TORCH_CHECK) when any of the above requirements is violated.
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
  int block_size,
  int chunk_size,
  const at::Tensor& noop_flag,
  const std::vector<std::vector<at::Tensor>>& tensor_lists,
  T callable,
  ArgTypes... args)
{
  TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
  // chunk_size is a divisor below; a non-positive value would divide by zero or yield nonsense chunk counts.
  TORCH_CHECK(chunk_size > 0, "chunk_size must be > 0");
  const int len0 = static_cast<int>(tensor_lists[0].size());
  TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");

  auto ref_device = tensor_lists[0][0].device();
  TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
  for (size_t l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
  {
    TORCH_CHECK(static_cast<int>(tensor_lists[l].size()) == len0, "Size mismatch among tensor lists");

    for (size_t t = 0; t < tensor_lists[l].size(); t++)
    {
      const at::Tensor& tensor = tensor_lists[l][t];
      // TODO:  Print which tensor fails.
      bool contiguous_memory = tensor.is_sparse() ? tensor._values().is_contiguous() : tensor.is_contiguous();
#ifdef VERSION_GE_1_5
      contiguous_memory = (contiguous_memory || tensor.is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
      TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
      TORCH_CHECK(tensor.device() == ref_device, "A tensor was not on the same device as the first tensor");
      TORCH_CHECK(tensor.numel() == tensor_lists[0][t].numel(), "Size mismatch");
    }
  }

  const int ntensors = len0;
  const int max_tensors = depth_to_max_tensors[depth-1];
  const int max_blocks = depth_to_max_blocks[depth-1];

  TensorListMetadata<depth> tl;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
  auto stream = at::cuda::getCurrentCUDAStream();

  // Dense temporaries for sparse inputs. Their data pointers are stored in `tl` and read
  // by launches issued later in this function, so they must stay alive until it returns.
  // They are allocated on the current stream, which is also the launch stream.
  // NOTE(review): for a sparse tensor the kernel sees only this temporary copy, so any
  // writes the functor makes to it are not propagated back to the sparse tensor. Confirm
  // that sparse tensors are only ever passed as read-only inputs.
  std::vector<at::Tensor> dense_copies;

  tl.start_tensor_this_launch = 0;
  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for (int t = 0; t < ntensors; t++)
  {
    tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++) {
      const at::Tensor& tensor = tensor_lists[d][t];
      if (tensor.is_sparse()) {
        at::Tensor dst = at::zeros(tensor.sizes(), tensor.options().layout(at::kStrided));
        dst.add_(tensor);
        tl.addresses[d][loc_tensor_info] = dst.data_ptr();
        dense_copies.push_back(std::move(dst));
      } else {
        tl.addresses[d][loc_tensor_info] = tensor.data_ptr();
      }
    }
    loc_tensor_info++;

    const int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;

    if (chunks_this_tensor == 0)
    {
      // An empty tensor occupies a slot but contributes no blocks, so the chunk loop below
      // never runs for it and cannot trigger a flush. If it took the last free slot, flush
      // here; otherwise the next tensor would be written past the end of tl.sizes.
      if (loc_tensor_info == max_tensors)
      {
        multi_tensor_apply_launch_blocks(loc_block_info, block_size, stream, chunk_size,
                                         noop_flag, tl, callable, args...);
        loc_block_info = 0;
        loc_tensor_info = 0;
        tl.start_tensor_this_launch = t + 1;
      }
      continue;
    }

    for (int chunk = 0; chunk < chunks_this_tensor; chunk++)
    {
      tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tl.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      const bool last_chunk_of_tensor = (chunk == chunks_this_tensor - 1);
      const bool tensors_full = (loc_tensor_info == max_tensors && last_chunk_of_tensor);
      const bool blocks_full = (loc_block_info == max_blocks);
      if (tensors_full || blocks_full)
      {
        multi_tensor_apply_launch_blocks(loc_block_info, block_size, stream, chunk_size,
                                         noop_flag, tl, callable, args...);

        // Reset.  The control flow possibilities here make my brain hurt.
        loc_block_info = 0;
        if (last_chunk_of_tensor)
        {
          // Tensor t is finished; the next launch starts fresh at tensor t + 1.
          loc_tensor_info = 0;
          tl.start_tensor_this_launch = t + 1;
        }
        else
        {
          // Tensor t still has chunks left: carry it over into slot 0 of the next launch.
          tl.sizes[0] = tl.sizes[loc_tensor_info-1];
          for (int d = 0; d < depth; d++)
            tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
          loc_tensor_info = 1;
          tl.start_tensor_this_launch = t;
        }
      }
    }
  }

  // Flush whatever is still pending. This covers the tail of the final tensor, and also
  // the case where the trailing tensors are empty, so no chunk of theirs could trigger a
  // launch. This is a no-op if nothing is pending.
  multi_tensor_apply_launch_blocks(loc_block_info, block_size, stream, chunk_size,
                                   noop_flag, tl, callable, args...);
}