"docs/source/en/api/loaders.md" did not exist on "fdf70cb54beb7df77bb46c3227b345115a96e505"
bcast.cc 3.08 KB
Newer Older
1
/**
2
 *  Copyright (c) 2020 by Contributors
3
4
 * @file bcast.cc
 * @brief Broadcast related function implementations.
5
6
7
 */
#include <dgl/bcast.h>
#include <dmlc/logging.h>

#include <algorithm>
#include <cstdint>

namespace dgl {

namespace {
14
/**
 * @brief Decide whether a binary op on lhs/rhs needs the broadcasting path.
 *
 * Axis 0 of each operand is never compared; only the trailing axes matter.
 *
 * @param op  Operator name. "copy_lhs" and "copy_rhs" only read one operand,
 *            so they never broadcast.
 * @param lhs Left operand.
 * @param rhs Right operand.
 * @return false for copy ops; otherwise true when the operands have a
 *         different number of dimensions, or when any axis other than
 *         axis 0 has a different size.
 */
bool UseBcast(const std::string& op, NDArray lhs, NDArray rhs) {
  const bool is_copy_op = (op == "copy_lhs") || (op == "copy_rhs");
  if (is_copy_op) {
    return false;  // copy_u/copy_e never need broadcasting
  }
  if (rhs->ndim != lhs->ndim) {
    return true;
  }
  // Same rank: walk the trailing axes until a size mismatch is found.
  int axis = 1;
  while (axis < lhs->ndim && rhs->shape[axis] == lhs->shape[axis]) {
    ++axis;
  }
  return axis < lhs->ndim;
}

}  // namespace

30
/**
 * @brief Compute broadcast and auxiliary information given operator
 *        and operands for kernel computation.
 *
 * Axis 0 of each operand is treated as the leading axis and is excluded
 * from every length and offset computed here.
 *
 * @param op  Operator name. "copy_lhs"/"copy_rhs" never broadcast. For "dot",
 *            the last axis of lhs is a reduction axis: its size is stored in
 *            reduce_size and it is excluded from out_len and the offsets.
 * @param lhs Left operand.
 * @param rhs Right operand.
 * @return A BcastOff with:
 *         - lhs_len / rhs_len: product of each operand's axes 1..ndim-1;
 *         - use_bcast: whether the broadcasting path is needed;
 *         - reduce_size: lhs's last-axis size for "dot", otherwise 1;
 *         - out_len: number of output elements per leading index;
 *         - lhs_offset / rhs_offset (only filled when use_bcast is true):
 *           for each output element, the offset into the lhs / rhs block
 *           that contributes to it.
 * @note Expect lhs, rhs to have ndim >= 2 and the shape of lhs/rhs
 *       valid for the op computation; shapes are not validated here.
 */
BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs) {
  BcastOff rst;
  rst.lhs_len = 1;
  rst.rhs_len = 1;
  for (int i = 1; i < lhs->ndim; ++i) rst.lhs_len *= lhs->shape[i];
  for (int i = 1; i < rhs->ndim; ++i) rst.rhs_len *= rhs->shape[i];
  rst.use_bcast = UseBcast(op, lhs, rhs);
  rst.reduce_size = 1;  // defaults to 1, except for the case op == 'dot'.
  const bool is_dot = (op == "dot");
  if (rst.use_bcast) {
    const int max_ndim = std::max(lhs->ndim, rhs->ndim) - 1;
    int j = 0;
    if (is_dot) {
      rst.reduce_size = lhs->shape[lhs->ndim - 1];  // set reduce_size for dot.
      ++j;  // do not consider reduce axis in computing lhs_offset and
            // rhs_offset.
    }
    // Shape entries are int64_t; keep sizes/strides in int64_t so large axes
    // are neither truncated nor overflowed.
    int64_t out_len = 1;
    int64_t stride_l = 1;
    int64_t stride_r = 1;
    rst.lhs_offset.push_back(0);  // lhs_offset[0] is always 0
    rst.rhs_offset.push_back(0);  // rhs_offset[0] is always 0
    for (; j < max_ndim; ++j) {   // iterate the axes from back to front.
      // dl / dr: size of lhs / rhs along the current axis. An operand with
      // fewer axes is treated as having size 1 here (left-padded shape).
      const int lhs_axis = lhs->ndim - 1 - j;
      const int rhs_axis = rhs->ndim - 1 - j;
      const int64_t dl = (lhs_axis < 1) ? 1 : lhs->shape[lhs_axis];
      const int64_t dr = (rhs_axis < 1) ? 1 : rhs->shape[rhs_axis];
      const int64_t dout = std::max(dl, dr);
      // The offset tables grow from out_len to out_len * dout entries.
      rst.lhs_offset.reserve(static_cast<size_t>(out_len * dout));
      rst.rhs_offset.reserve(static_cast<size_t>(out_len * dout));
      for (int64_t i = 1; i < dout; ++i) {
        for (int64_t k = 0; k < out_len; ++k) {
          /* Explanation:
           * if current dimension is not broadcast dimension for lhs array
           *   lhs_offset[i * out_len + k] = lhs_offset[k] + i * stride_l
           * else
           *   lhs_offset[i * out_len + k] = lhs_offset[k]
           * likewise for rhs_offset.
           */
          rst.lhs_offset.push_back(
              rst.lhs_offset[k] + (i < dl ? i * stride_l : 0));
          rst.rhs_offset.push_back(
              rst.rhs_offset[k] + (i < dr ? i * stride_r : 0));
        }
      }
      out_len *= dout;
      stride_l *= dl;
      stride_r *= dr;
    }
    rst.out_len = out_len;
  } else {
    rst.out_len = (op == "copy_rhs") ? rst.rhs_len : rst.lhs_len;
    if (is_dot) {
      // set reduce_size for dot.
      rst.reduce_size = lhs->shape[lhs->ndim - 1];
      // out_len is the product of lhs's non-leading axes excluding the reduce
      // axis. Computing it directly (instead of lhs_len / reduce_size) gives
      // the same value but avoids an integer division by zero when the
      // reduce axis has size 0.
      int64_t out_len = 1;
      for (int i = 1; i < lhs->ndim - 1; ++i) out_len *= lhs->shape[i];
      rst.out_len = out_len;
    }
  }
  return rst;
}

}  // namespace dgl