// multi_tensor_apply.cuh
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <assert.h>
#include <cuda_runtime.h>
#include <limits>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

// Per-launch capacity limits for multi_tensor_apply, indexed by depth-1
// (depth = number of parallel tensor lists, 1..5).
//
// A TensorList<depth> built from these limits is passed to the kernel by value,
// so they bound the size of the kernel argument. TensorList::addresses holds
// `depth` pointers per tensor slot, which is presumably why the tensor limit
// shrinks as depth grows.
// TODO:  Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

// Kernel-argument payload describing the work of one multi_tensor_apply launch.
//
// `n` is the depth: the number of parallel tensor lists processed together.
// One launch can reference at most depth_to_max_tensors[n-1] tensor slots and
// depth_to_max_blocks[n-1] (tensor, chunk) block entries.
template<int n> struct TensorList
{
  static_assert(n >= 1 &&
                n <= static_cast<int>(sizeof(depth_to_max_tensors) / sizeof(depth_to_max_tensors[0])),
                "TensorList depth is outside the range covered by depth_to_max_tensors");
  // block_to_tensor stores tensor slot indices (0 .. max_tensors-1) as unsigned char.
  static_assert(depth_to_max_tensors[n-1] <= 256,
                "depth_to_max_tensors entry too large for unsigned char block_to_tensor");

  // addresses[d][i]: data pointer of the tensor in slot i of list d.
  void* addresses[n][depth_to_max_tensors[n-1]];
  // sizes[i]: element count of the tensor in slot i (identical across lists).
  int sizes[depth_to_max_tensors[n-1]];
  // block_to_tensor[b]: tensor slot handled by block entry b (one entry per launched block).
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  // block_to_chunk[b]: index of the chunk within that tensor handled by entry b.
  int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
};


// Device entry point used by multi_tensor_apply.
//
// The kernel does no work of its own: every block simply forwards the chunk
// size, the no-op flag pointer, the by-value TensorList payload and any extra
// arguments to the user-supplied functor. The functor is responsible for
// working out (presumably from blockIdx and the payload's block_to_* tables)
// which chunk the calling block must process.
template<typename TensorListT, typename Functor, typename... ExtraArgs>
__global__ void multi_tensor_apply_kernel(
    int chunk_size,
    volatile int* noop_flag,
    TensorListT tl,
    Functor callable,
    ExtraArgs... args) // e.g. in_t** in, float** out, float scale
{
  callable(chunk_size, noop_flag, tl, args...);
}

// Applies `callable` to every chunk of every tensor in `tensor_lists`.
//
// Tensors are split into chunks of `chunk_size` elements, and each chunk becomes
// one CUDA block. As many (tensor, chunk) pairs as fit into a TensorList<depth>
// are batched into a single launch of multi_tensor_apply_kernel. A new launch is
// issued whenever the tensor slots or the block entries run out, or when the
// final chunk has been queued.
//
// Parameters:
//   block_size   - threads per block for every launch.
//   chunk_size   - elements per chunk; must be > 0.
//   noop_flag    - tensor whose data<int>() pointer is forwarded to each launch
//                  (assumed to be a CUDA int tensor -- confirm at call sites).
//   tensor_lists - exactly `depth` lists, all of the same non-zero length, of
//                  contiguous CUDA tensors. The t-th tensor of every list must
//                  have the same numel, and that numel must fit in an int.
//   callable     - device functor invoked as callable(chunk_size, noop_flag, tl, args...).
//   args         - extra arguments forwarded by value to every launch.
//
// Tensors with zero elements are skipped. Launches go to the current CUDA stream
// and are asynchronous. Launch errors are surfaced through AT_CUDA_CHECK.
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
  int block_size,
  int chunk_size,
  const at::Tensor& noop_flag,
  const std::vector<std::vector<at::Tensor>>& tensor_lists,
  T callable,
  ArgTypes... args)
{
  AT_CHECK(tensor_lists.size() == static_cast<size_t>(depth), "tensor_lists.size() != depth");
  AT_CHECK(chunk_size > 0, "chunk_size must be > 0");
  const size_t len0 = tensor_lists[0].size();
  AT_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");

  for(size_t l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
  {
    AT_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
    for(size_t t = 0; t < tensor_lists[l].size(); t++)
    {
      // TODO:  Print which tensor fails.
      AT_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
      AT_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
      AT_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
      // TensorList::sizes stores element counts as int; larger tensors would be truncated.
      AT_CHECK(tensor_lists[l][t].numel() <= std::numeric_limits<int>::max(),
               "A tensor has more elements than fit in an int.");
    }
  }

  const int ntensors = static_cast<int>(len0);
  constexpr int max_tensors = depth_to_max_tensors[depth-1];
  constexpr int max_blocks = depth_to_max_blocks[depth-1];

  // Pending work is flushed when the final chunk of the last tensor that
  // actually has chunks is queued. Trailing empty tensors must not hide it.
  int last_nonempty = -1;
  for(int t = 0; t < ntensors; t++)
    if(tensor_lists[0][t].numel() > 0)
      last_nonempty = t;

  TensorList<depth> tl;

  auto stream = at::cuda::getCurrentCUDAStream();

  int loc_block_info = 0;   // block entries queued in tl for the next launch
  int loc_tensor_info = 0;  // tensor slots occupied in tl for the next launch
  for(int t = 0; t < ntensors; t++)
  {
    const auto numel = tensor_lists[0][t].numel();

    // An empty tensor has no chunks. Registering it would consume a tensor slot
    // that the tensors_full check below (evaluated per chunk) never sees, which
    // could overflow tl. So it never gets a slot.
    if(numel == 0)
      continue;

    tl.sizes[loc_tensor_info] = static_cast<int>(numel);
    for(int d = 0; d < depth; d++)
      tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
    loc_tensor_info++;

    const int chunks_this_tensor = static_cast<int>((numel + chunk_size - 1)/chunk_size);

    for(int chunk = 0; chunk < chunks_this_tensor; chunk++)
    {
      tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tl.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      const bool last_chunk_this_tensor = (chunk == chunks_this_tensor - 1);
      // Tensor slots are only "full" once the current tensor's chunks are all queued.
      // Until then the remaining chunks keep reusing its slot.
      const bool tensors_full = (loc_tensor_info == max_tensors && last_chunk_this_tensor);
      const bool blocks_full = (loc_block_info == max_blocks);
      const bool last_chunk = (t == last_nonempty && last_chunk_this_tensor);
      if(tensors_full || blocks_full || last_chunk)
      {
        multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
          chunk_size,
          noop_flag.data<int>(),
          tl,
          callable,
          args...);

        AT_CUDA_CHECK(cudaGetLastError());

        // Reset for the next launch.
        loc_block_info = 0;
        if(last_chunk_this_tensor)
        {
          // The current tensor is finished; start the next launch with empty slots.
          loc_tensor_info = 0;
        }
        else
        {
          // The current tensor still has chunks left: carry it over into slot 0
          // so its remaining chunks can reference it in the next launch.
          tl.sizes[0] = tl.sizes[loc_tensor_info-1];
          for(int d = 0; d < depth; d++)
            tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
          loc_tensor_info = 1;
        }
      }
    }
  }
}