segment_kernel.cu 14.7 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
rusty1s's avatar
rusty1s committed
3
4
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
#include "atomics.cuh"
rusty1s's avatar
rusty1s committed
7
8
#include "compat.cuh"

rusty1s's avatar
rusty1s committed
9
// Threads per block for the 1D launches in this file.
#define THREADS 256
// Number of `THREADS`-sized blocks needed to provide `TB * N` threads
// (ceil-division). Arguments and result are parenthesized so the macro
// composes safely with surrounding arithmetic.
#define BLOCKS(TB, N) ((((TB) * (N)) + THREADS - 1) / THREADS)
// Participation mask naming all 32 lanes of a warp.
#define FULL_MASK 0xffffffff

rusty1s's avatar
rusty1s committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// Reduction operations understood by the segment kernels.
enum ReductionType {
  ADD,  // sum of the entries
  MEAN, // sum divided by the number of entries
  MIN,  // smallest entry, with its position tracked as `arg`
  MAX   // largest entry, with its position tracked as `arg`
};

// Converts the runtime `reduce` string ("add", "mean", "min" or "max") into
// the compile-time constant `REDUCE` (a `ReductionType`) and invokes the given
// lambda with that constant in scope. An unrecognized string raises an error
// instead of silently skipping the lambda (which would leave any output the
// lambda was supposed to fill untouched).
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                               \
  [&] {                                                                        \
    if (reduce == "add") {                                                     \
      const ReductionType REDUCE = ADD;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "mean") {                                             \
      const ReductionType REDUCE = MEAN;                                       \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "min") {                                              \
      const ReductionType REDUCE = MIN;                                        \
      return __VA_ARGS__();                                                    \
    } else if (reduce == "max") {                                              \
      const ReductionType REDUCE = MAX;                                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    AT_ERROR("Unsupported reduction type: ", reduce);                          \
  }()

// Compile-time selected reduction logic shared by all segment kernels.
// `val` holds the running result; `arg` holds the position of the currently
// selected entry and is only maintained for MIN/MAX.
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
  // Identity of the reduction: the largest representable value for MIN, the
  // lowest representable value for MAX, and zero for ADD/MEAN.
  static inline __host__ __device__ scalar_t init() {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      // For floating-point types `min()` is the smallest *positive* value, so
      // it is not an identity for negative inputs; `lowest()` is.
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  // Folds `new_val` (located at position `new_arg`) into `*val`.
  // ADD/MEAN accumulate and leave `*arg` untouched. MIN/MAX replace the
  // current value only when `new_val` is strictly better, so ties keep the
  // current `*arg`; a worse value is ignored.
  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                int64_t *arg, int64_t new_arg) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
      *arg = new_arg;
    }
  }

  // Stores the final result of a segment holding `count` entries.
  // MEAN divides by `count` (clamped to at least 1). MIN/MAX write `val` and
  // `arg` for non-empty segments, and write 0 (leaving `*arg_address`
  // untouched) for empty ones.
  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
                                               int64_t *arg_address,
                                               int64_t arg, int count) {
    if (REDUCE == ADD) {
      *address = val;
    } else if (REDUCE == MEAN) {
      *address = val / (scalar_t)max(count, 1);
    } else if (REDUCE == MIN || REDUCE == MAX) {
      if (count > 0) {
        *address = val;
        *arg_address = arg;
      } else {
        *address = (scalar_t)0;
      }
    }
  }

  // Atomically merges a partial result into `*address`; unlike `write`, other
  // threads may target the same address concurrently.
  //
  // This is called from divergent code (in `segment_coo_kernel` only some
  // lanes reach it), so no block-wide barrier may be used here:
  // `__syncthreads()` inside a branch not taken by the whole block is
  // undefined behavior and can hang.
  // NOTE(review): the `*arg_address` update is still racy. Another thread may
  // replace `*address` after our comparison, and concurrent writers of
  // `*arg_address` are unordered. A deterministic arg needs a separate pass
  // after all value updates have completed.
  static inline __device__ void atom_write(scalar_t *address, scalar_t val,
                                           int64_t *arg_address, int64_t arg) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      atomAdd(address, val);
    } else if (REDUCE == MIN && val < *address) {
      atomMin(address, val);
    } else if (REDUCE == MAX && val > *address) {
      atomMax(address, val);
    }

    if (REDUCE == MIN || REDUCE == MAX) {
      // Order our own atomic update before re-reading the target, and read it
      // through `volatile` so a stale cached copy is not reused.
      __threadfence();
      scalar_t current = *(volatile scalar_t *)address;
      if (current == val) {
        *arg_address = arg;
      }
    }
  }
};
rusty1s's avatar
rusty1s committed
91

rusty1s's avatar
atomics  
rusty1s committed
92
93
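// We need our own `IndexToOffset` implementation: the last dimension of
// `indptr` holds one more entry than there are segments, and that trailing
// entry (the end of the final segment) must never be addressed as a segment
// start.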
rusty1s's avatar
rusty1s committed
94
95
96
97
// Maps the linear index `idx` of a segment to the element offset of its start
// entry in `indptr`. The last dimension is treated as having `size - 1`
// positions, because its trailing entry only closes the final segment.
template <typename scalar_t> struct IndexPtrToOffset {
  static inline __host__ __device__ int
  get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
    const int last_dim = info.dims - 1;
    const int segments_per_row = info.sizes[last_dim] - 1;

    // Innermost dimension: position of the segment within its row.
    int result = (idx % segments_per_row) * info.strides[last_dim];
    int remaining = idx / segments_per_row;

    // Outer dimensions: regular mixed-radix decomposition of the row index.
    for (int d = last_dim - 1; d >= 0; d--) {
      const int size_d = info.sizes[d];
      result += (remaining % size_d) * info.strides[d];
      remaining /= size_d;
    }
    return result;
  }
};

rusty1s's avatar
rusty1s committed
108
109
110
111
112
113
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
                   scalar_t *out_data, int64_t *arg_out_data, size_t N,
                   size_t E) {

  // Each warp processes exactly `32/TB` rows. The `TB` threads assigned to a
  // row each reduce a strided subset of its entries, and their partial
  // results are then combined via a warp-level parallel reduction.
  // Expects a 1D launch providing at least `TB * N` threads, with `TB` a
  // power of two no larger than 32 and a block size that is a multiple of 32.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  // Lanes whose row lies past `N` skip the body below, so the shuffles must
  // only name the lanes that actually execute them. The ballot is taken
  // before the branch so that every lane of the warp participates in it.
  unsigned mask = __ballot_sync(FULL_MASK, row_idx < N);

  if (row_idx < N) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    // Default to the row's first entry. `arg` is only overwritten when an
    // entry strictly beats `init()`, and initializing both variables keeps
    // indeterminate values out of the shuffles and `update` calls below.
    int64_t arg = row_start, arg_tmp = row_start;

    // Start of this row's batch in `src` (rows of length `E`).
    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
      Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
                                        src_idx);
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2) {
      // Parallel reduction among the `TB` lanes of this row; `width = TB`
      // keeps every shuffle inside the row's own lane group.
      if (REDUCE == MIN || REDUCE == MAX) {
        arg_tmp = __shfl_down_sync(mask, arg, i, TB);
      }
      Reducer<scalar_t, REDUCE>::update(
          &val, __shfl_down_sync(mask, val, i, TB), &arg, arg_tmp);
    }

    if (lane_idx == 0) {
      Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
                                       arg_out_data + row_idx, arg,
                                       row_end - row_start);
    }
  }
}

rusty1s's avatar
rusty1s committed
155
156
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {

  // Each thread processes exactly one (row, column) pair of the output: it
  // walks the row's segment sequentially and reduces the entries of its
  // column. It turned out that this is more efficient than using shared
  // memory due to avoiding synchronization barriers.
  // Expects a 1D launch providing at least `N * K` threads.

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    // Default to the segment's first entry so `write` never stores an
    // uninitialized position when no entry strictly beats `init()`.
    int64_t arg = row_start;

    // Start of this row's batch in `src`, indexed as [batch][E][K].
    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int src_idx = row_start; src_idx < row_end; src_idx++) {
      Reducer<scalar_t, REDUCE>::update(
          &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
    }

    Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
                                     arg_out_data + thread_idx, arg,
                                     row_end - row_start);
  }
}

rusty1s's avatar
rusty1s committed
190
191
192
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
                 at::optional<at::Tensor> out_opt, std::string reduce) {
  // Reduces `src` along dimension `indptr.dim() - 1` over the segments
  // delimited by consecutive entries of `indptr`'s last dimension.
  // Returns the reduced tensor and, for "min"/"max", the positions of the
  // selected entries; positions of empty segments stay `src.size(reduce_dim)`.

  AT_ASSERTM(src.dim() >= indptr.dim());
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i));
  AT_ASSERTM(reduce == "add" || reduce == "mean" || reduce == "min" ||
                 reduce == "max",
             "Unsupported reduction type: ", reduce);

  src = src.contiguous();
  auto reduce_dim = indptr.dim() - 1;

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != reduce_dim)
        AT_ASSERTM(src.size(i) == out.size(i));
    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1);
  } else {
    auto sizes = src.sizes().vec();
    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
    out = at::empty(sizes, src.options());
  }

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  // Empty output: nothing to compute. Returning here also avoids the
  // division by zero in `K` below and zero-sized (invalid) kernel launches.
  if (out.numel() == 0)
    return std::make_tuple(out, arg_out);

  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (K == 1) {
        // The kernel maps `TB` threads to every output row, so the grid must
        // be sized for exactly `TB * N` threads with the same `TB`.
        constexpr int TB = 1;
        segment_csr_kernel<scalar_t, REDUCE, TB>
            <<<BLOCKS(TB, N), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, E);
      } else {
        segment_csr_broadcast_kernel<scalar_t, REDUCE>
            <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
                src_data, indptr_info, out_data, arg_out_data, N, K, E);
      }
    });
  });

  // Surface launch-configuration errors here rather than at some later,
  // unrelated CUDA call.
  cudaError_t err = cudaGetLastError();
  AT_ASSERTM(err == cudaSuccess, "segment_csr_cuda: kernel launch failed: ",
             cudaGetErrorString(err));

  return std::make_tuple(out, arg_out);
}

rusty1s's avatar
rusty1s committed
247
248
249
250
251
template <typename scalar_t, ReductionType REDUCE>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                   scalar_t *out_data, int64_t *arg_out_data, size_t E) {

  // Each thread processes exactly one entry. Within a warp, we perform a
  // parallel (segmented) reduction across equal indices, and the last lane of
  // every run of equal indices writes the intermediate result via atomics.
  // Assumes a 1D launch whose block size is a multiple of 32 (so that
  // `row_idx & 31` is the lane within the warp) and `index` values that are
  // non-decreasing within a warp (checked by the `assert` below).

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);

  // In the last warp, lanes with `row_idx >= E` skip the body below, so the
  // shuffles must only name the lanes that actually execute them. The ballot
  // is taken before the branch so that every lane of the warp participates.
  unsigned mask = __ballot_sync(FULL_MASK, row_idx < E);

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = index_info.data[offset], next_idx;

    scalar_t val = src_data[row_idx], tmp;
    int64_t arg = row_idx % index_info.sizes[index_info.dims - 1];
    // Initialized so `update` never receives an indeterminate value.
    int64_t arg_tmp = arg;

#pragma unroll
    for (int i = 1; i < 32; i *= 2) {
      // Parallel reduction inside a single warp. Out-of-range lanes form a
      // suffix of the warp, so every source lane `lane_idx - i` is active.
      tmp = __shfl_up_sync(mask, val, i);
      if (REDUCE == MIN || REDUCE == MAX) {
        arg_tmp = __shfl_up_sync(mask, arg, i);
      }
      next_idx = __shfl_up_sync(mask, idx, i);
      assert(idx >= next_idx);
      if (lane_idx >= i && idx == next_idx)
        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
    }

    next_idx = __shfl_down_sync(mask, idx, 1);
    // The final entry has no in-range successor (the value shuffled in from
    // the next lane is undefined), so it must always flush its run.
    if (lane_idx == 32 - 1 || (size_t)row_idx + 1 == E || idx != next_idx) {
      Reducer<scalar_t, REDUCE>::atom_write(out_data + idx, val,
                                            arg_out_data + idx, arg);
    }
  }
}

rusty1s's avatar
rusty1s committed
289
290
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {

  // One thread owns one column (`col`) across a run of up to `TB` consecutive
  // index entries starting at `first`. Threads adjacent in x touch adjacent
  // columns, so reads of `src` and atomic writes to `out` are coalesced.
  // The running sum is flushed with `atomAdd` whenever the index value
  // changes, and once more after the run.
  // Expects a 2D launch: the x dimension spans columns, the y dimension spans
  // runs of `TB` entries.
  // NOTE(review): only summation is performed here regardless of `REDUCE`,
  // and `arg_out_data` is unused.

  const int first = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
  const int col = blockIdx.y * blockDim.x + threadIdx.x;
  if (first >= E || col >= K)
    return;

  const int base =
      at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(first, index_info);
  const int step = index_info.strides[index_info.dims - 1];

  int cur = __ldg(index_info.data + base);
  scalar_t acc = src_data[K * first + col];

#pragma unroll
  for (int j = 1; j < TB; j++) {
    const int row = first + j;
    if (row >= E)
      break;

    const int nxt = __ldg(index_info.data + base + j * step);
    assert(cur <= nxt);

    if (nxt == cur) {
      acc += src_data[K * row + col];
    } else {
      atomAdd(out_data + K * cur + col, acc);
      acc = src_data[K * row + col];
    }
    cur = nxt;
  }

  atomAdd(out_data + K * cur + col, acc);
}

rusty1s's avatar
rusty1s committed
330
331
332
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                 std::string reduce) {
  // Scatters `src` into `out` along dimension `index.dim() - 1`, where
  // `index` gives the output position of every entry (expected sorted; the
  // kernels assert this). Results are accumulated into `out` with atomics,
  // so `out` is presumably pre-filled by the caller with the reduction's
  // identity -- NOTE(review): confirm against callers. Returns `out` and, for
  // "min"/"max", the positions of the selected entries.

  AT_ASSERTM(src.dim() >= index.dim());
  for (int i = 0; i < index.dim(); i++)
    AT_ASSERTM(src.size(i) == index.size(i));
  AT_ASSERTM(reduce == "add" || reduce == "mean" || reduce == "min" ||
                 reduce == "max",
             "Unsupported reduction type: ", reduce);
  // TODO: Divide entries by their segment sizes. Until then, fail before any
  // kernel has modified `out` in place.
  AT_ASSERTM(reduce != "mean",
             "segment_coo_cuda: reduce=\"mean\" is not implemented yet");

  src = src.contiguous();
  out = out.contiguous();
  auto reduce_dim = index.dim() - 1;

  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i));

  at::optional<at::Tensor> arg_out = at::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce == "min" || reduce == "max") {
    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
  }

  // No entries to scatter. Returning here also avoids the division by zero
  // in `K` below and zero-sized (invalid) kernel launches.
  if (src.numel() == 0)
    return std::make_tuple(out, arg_out);

  auto E = index.numel();
  auto K = src.numel() / index.numel();
  // The broadcasting kernels below only implement summation.
  AT_ASSERTM(K == 1 || reduce == "add", "segment_coo_cuda: reduce=\"", reduce,
             "\" requires `src` to have no dimensions beyond those of `index`");
  auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    // Select the right kernel based on average row length (purely heuristic)
    // and whether we need broadcasting capabilities (K > 1):
    if (K == 1) {
      AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
        segment_coo_kernel<scalar_t, REDUCE>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                   out_data, arg_out_data, E);
      });
    } else if (avg_len <= 8) {
      segment_coo_broadcast_kernel<scalar_t, ADD, 4>
          <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
    } else if (avg_len <= 16) {
      segment_coo_broadcast_kernel<scalar_t, ADD, 8>
          <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, nullptr, E, K);
    } else if (avg_len <= 32) {
      segment_coo_broadcast_kernel<scalar_t, ADD, 16>
          <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
    } else {
      segment_coo_broadcast_kernel<scalar_t, ADD, 32>
          <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
    }
  });

  // Surface launch-configuration errors here rather than at some later,
  // unrelated CUDA call.
  cudaError_t err = cudaGetLastError();
  AT_ASSERTM(err == cudaSuccess, "segment_coo_cuda: kernel launch failed: ",
             cudaGetErrorString(err));

  return std::make_tuple(out, arg_out);
}