// multi_tensor_apply.cuh
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <THC/THC.h>

#include <cstring>
#include <vector>

#include "compat.h"

#include <assert.h>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

// TODO:  Kernel arg size limit may be <4KB for some other cards (ie Jetson)
// Capacity tables, indexed by (depth - 1), where "depth" is how many tensor
// lists an operation consumes. Deeper ops fit fewer tensors per launch because
// every tensor contributes one pointer per list to the metadata struct below,
// whose total size these tables bound (historically to respect the kernel-arg
// size limit; in this variant the struct travels via pinned host memory).
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5]  = {320, 320, 320, 320, 320};

// Plain-old-data bundle describing one batch of (tensor, chunk) work items.
// Filled on the host by multi_tensor_apply and consumed on the device by the
// user functor via multi_tensor_apply_kernel.
template<int n> struct TensorListMetadata
{
  // addresses[d][i]: data pointer of tensor i from list d.
  void* addresses[n][depth_to_max_tensors[n-1]];
  // sizes[i]: element count of tensor i (shared across all n lists).
  int sizes[depth_to_max_tensors[n-1]];
  // block_to_tensor[b]: which packed tensor CUDA block b works on.
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  // block_to_chunk[b]: which chunk of that tensor block b handles.
  int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
  // Global index of the first tensor covered by the current launch.
  int start_tensor_this_launch;
};


// Launch trampoline: each CUDA block simply hands the chunk bookkeeping and
// any extra arguments to the user-supplied functor, which does the real work
// (it identifies its tensor/chunk through `tl` and its blockIdx).
template<typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
    int chunk_size,           // elements each block is responsible for
    volatile int* noop_flag,  // device flag the functor may consult to skip work
    T* tl,                    // tensor/chunk metadata (TensorListMetadata<depth>)
    U callable,               // user functor, invoked once per block
    ArgTypes... args)         // forwarded verbatim to the functor
{
  callable(chunk_size, noop_flag, tl, args...);
}

// Batches elementwise work across many tensors into as few kernel launches as
// possible.
//
// tensor_lists is a [depth][ntensors] matrix of CUDA tensors; position t of
// every list must hold the same number of elements. Each tensor is cut into
// chunk_size-element chunks, and (tensor, chunk) pairs are packed into a
// TensorListMetadata<depth> until the tensor or block capacity table fills up
// (or the final chunk is reached); each time that happens, one kernel is
// launched over the packed blocks and packing restarts.
//
// block_size   - threads per CUDA block for each launch.
// chunk_size   - elements each CUDA block processes.
// noop_flag    - int CUDA tensor; only its data pointer is forwarded, so the
//                functor can flag/skip work (e.g. on overflow).
// tensor_lists - depth lists of equal length; every tensor must be CUDA and
//                contiguous (dense, the _values() of a sparse tensor, or
//                ChannelsLast when VERSION_GE_1_5 is defined).
// callable     - functor invoked on-device once per packed block.
// args...      - extra arguments forwarded verbatim to the functor.
template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
  int block_size,
  int chunk_size,
  const at::Tensor& noop_flag,
  const std::vector<std::vector<at::Tensor>>& tensor_lists,
  T callable,
  ArgTypes... args)
{
  TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
  int len0 = tensor_lists[0].size();
  TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");

  for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
  {
    TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
    for(int t = 0; t < tensor_lists[l].size(); t++)
    {
      // TODO:  Print which tensor fails.
      bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
      contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
      TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
      TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
      TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
    }
  }

  int ntensors = tensor_lists[0].size();

  TensorListMetadata<depth> tl;

  auto stream = at::cuda::getCurrentCUDAStream();

  // BUG FIX: densified copies of sparse tensors must outlive the kernel
  // launches that read them. Previously each copy was a temporary that died
  // at the end of the depth loop iteration, so the caching allocator could
  // hand its block to a later at::zeros whose zero-fill was enqueued on the
  // stream ahead of the kernel that still needed the old contents, leaving
  // tl.addresses pointing at clobbered memory. Holding them here keeps the
  // storage valid for the whole packing/launch sequence.
  std::vector<at::Tensor> densified_sparse;

  tl.start_tensor_this_launch = 0;
  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for(int t = 0; t < ntensors; t++)
  {
    tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
    for(int d = 0; d < depth; d++) {
      if (tensor_lists[d][t].is_sparse()) {
        // Densify so the kernel sees flat contiguous memory.
        at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided));
        dst.add_(tensor_lists[d][t]);
        densified_sparse.push_back(dst); // keep alive past the launch (see above)
        tl.addresses[d][loc_tensor_info] = dst.data_ptr();
      } else {
        tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
      }
    }
    loc_tensor_info++;

    int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;

    for(int chunk = 0; chunk < chunks_this_tensor; chunk++)
    {
      tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tl.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
                           chunk == chunks_this_tensor - 1);
      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
      bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
      if(tensors_full || blocks_full || last_chunk)
      {
        // Stage the metadata in pinned host memory; the kernel dereferences
        // it directly. recordEvent registers the stream's use of the pinned
        // block with the caching host allocator so the buffer is not recycled
        // until the stream is done with it (presumably at free time -- the
        // tensor may be destroyed here before the async kernel completes).
        auto storage = at::empty(sizeof(tl), c10::TensorOptions(at::kStrided).dtype(at::kByte).device(at::kCPU).pinned_memory(true));
        auto tl_as_host_pinned_ptr = static_cast<decltype(tl)*>(storage.data_ptr());
        memcpy(tl_as_host_pinned_ptr, &tl, sizeof(tl));
        AT_CUDA_CHECK(THCCachingHostAllocator_recordEvent(tl_as_host_pinned_ptr, stream));

        multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
          chunk_size,
          noop_flag.DATA_PTR<int>(),
          tl_as_host_pinned_ptr,
          callable,
          args...);

        AT_CUDA_CHECK(cudaGetLastError());

        // Reset.  The control flow possibilities here make my brain hurt.
        loc_block_info = 0;
        if(chunk == chunks_this_tensor - 1)
        {
          // Current tensor is finished: restart packing from scratch.
          loc_tensor_info = 0;
          tl.start_tensor_this_launch = t + 1;
        }
        else
        {
          // Mid-tensor flush: carry this tensor's entry down to slot 0 so its
          // remaining chunks can be packed into the next launch.
          tl.sizes[0] = tl.sizes[loc_tensor_info-1];
          for(int d = 0; d < depth; d++)
            tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
          loc_tensor_info = 1;
          tl.start_tensor_this_launch = t;
        }
      }
    }
  }
}