// trees.cc
#include "sccl.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace detect {

// Map rank `r` to a dense index that skips the root: ranks above `root` shift
// down by one, ranks below it are unchanged.
// NOTE: `root` is not a macro parameter; it must be a variable in scope at the
// point of expansion. `r` is evaluated more than once, so pass a side-effect
// free expression.
#define RANK_TO_INDEX(r) ((r) > root ? (r) - 1 : (r))

/* Btree which alternates leaves and nodes.
 * Assumes root is 0, which conveniently builds a tree on powers of two,
 * (because we have pow2-1 ranks) which lets us manipulate bits.
 * Find first non-zero bit, then :
 * Find the parent :
 *   xx01[0] -> xx10[0] (1,5,9 below) or xx00[0] if xx10[0] is out of bounds (13 below)
 *   xx11[0] -> xx10[0] (3,7,11 below)
 * Find the children :
 *   xx10[0] -> xx01[0] (2,4,6,8,10,12) or -1 (1,3,5,7,9,11,13)
 *   xx10[0] -> xx11[0] (2,4,6,8,10) or xx101[0] (12) or xx1001[0] ... or -1 (1,3,5,7,9,11,13)
 *
 * Illustration :
 * 0---------------8
 *          ______/ \______
 *         4               12
 *       /   \            /  \
 *     2       6       10     \
 *    / \     / \     /  \     \
 *   1   3   5   7   9   11    13
 *
 * Inputs  : nranks (number of ranks) and rank; the code expects
 *           0 <= rank < nranks but does not validate it.
 * Outputs : (-1 always means "none")
 *   u               : parent of rank; -1 for rank 0 (the root).
 *   d0              : lower child (rank - lowbit), or -1 for a leaf / the root.
 *   d1              : upper child, or -1.
 *   parentChildType : 0 when rank < parent (we are the parent's d0),
 *                     1 when rank > parent (we are the parent's d1),
 *                     -1 for the root, which has no parent.
 * Returns : always scclSuccess.
 */
scclResult_t scclGetBtree(int nranks, int rank, int* u, int* d0, int* d1, int* parentChildType) {
    // Lowest set bit of rank. For rank 0 no bit ever matches, so the scan stops
    // at the first power of two that is >= nranks.
    int bit = 1;
    while(bit < nranks && (bit & rank) == 0) {
        bit <<= 1;
    }

    if(rank == 0) {
        *u  = -1;
        *d0 = -1;
        // Child rank is > 0 so it has to be our child 1, not 0.
        *d1 = nranks > 1 ? bit >> 1 : -1;
        // The root has no parent; still write the output so the caller never
        // reads an indeterminate value.
        *parentChildType = -1;
        return scclSuccess;
    }

    // Parent: clear our lowest set bit and set the next higher one. If that
    // lands outside the tree, fall back to just clearing our lowest bit.
    int up = (rank ^ bit) | (bit << 1);
    if(up >= nranks) {
        up = rank ^ bit;
    }
    *u = up;
    // If smaller than the parent, we are its first child, otherwise its second.
    *parentChildType = (rank < up) ? 0 : 1;

    // Children sit half a "bit" below and above us; lowbit == 0 means leaf.
    const int lowbit = bit >> 1;

    // down0 is always within bounds.
    *d0 = (lowbit == 0) ? -1 : rank - lowbit;

    // down1 may fall off the end of the tree: keep halving the distance until
    // it fits, or give up with -1.
    int down1 = -1;
    for(int step = lowbit; step > 0; step >>= 1) {
        if(rank + step < nranks) {
            down1 = rank + step;
            break;
        }
    }
    *d1 = down1;

    return scclSuccess;
}

/* Build a double binary tree. Take the previous tree for the first tree.
 * For the second tree, we use a mirror tree (if nranks is even)
 *
 * 0---------------8                   3----------------11
 *          ______/ \                 / \______
 *         4         \               /         7
 *       /   \        \             /        /   \
 *     2       6       10         1        5      9
 *    / \     / \     /  \       / \      / \    / \
 *   1   3   5   7   9   11     0   2    4   6  8   10
 *
 * or shift it by one rank (if nranks is odd).
 *
 * 0---------------8            1---------------9
 *          ______/ \______              ______/ \______
 *         4               12           5                0
 *       /   \            /           /   \            /
 *     2       6       10           3       7       11
 *    / \     / \     /  \         / \     / \     /  \
 *   1   3   5   7   9   11       2   4   6   8  10   12
 *
 * Outputs (-1 means "none"):
 *   s0, d0_0, d0_1, parentChildType0 : parent, children and child slot of rank
 *                                      in the first tree, exactly as produced
 *                                      by scclGetBtree(nranks, rank, ...).
 *   s1, d1_0, d1_1, parentChildType1 : the same for the second tree, with the
 *                                      rank numbers mapped back from the
 *                                      mirrored / shifted numbering.
 * Returns the first non-success code from scclGetBtree, else scclSuccess.
 */
scclResult_t scclGetDtree(int nranks, int rank, int* s0, int* d0_0, int* d0_1, int* parentChildType0, int* s1, int* d1_0, int* d1_1, int* parentChildType1) {
    // First tree ... use a btree
    scclResult_t res = scclGetBtree(nranks, rank, s0, d0_0, d0_1, parentChildType0);
    if(res != scclSuccess) {
        return res;
    }

    // Second tree ... mirror or shift. Neighbours are first computed in the
    // transformed numbering, then translated back to real ranks below.
    int up = -1, down0 = -1, down1 = -1;
    if(nranks % 2 == 1) {
        // shift: real rank r plays the role of r-1 (mod nranks) in the btree
        const int shiftrank = (rank - 1 + nranks) % nranks;
        res = scclGetBtree(nranks, shiftrank, &up, &down0, &down1, parentChildType1);
        if(res != scclSuccess) {
            return res;
        }
        *s1   = (up == -1) ? -1 : (up + 1) % nranks;
        *d1_0 = (down0 == -1) ? -1 : (down0 + 1) % nranks;
        *d1_1 = (down1 == -1) ? -1 : (down1 + 1) % nranks;
    } else {
        // mirror: real rank r plays the role of nranks-1-r in the btree
        const int mirrorrank = nranks - 1 - rank;
        res = scclGetBtree(nranks, mirrorrank, &up, &down0, &down1, parentChildType1);
        if(res != scclSuccess) {
            return res;
        }
        *s1   = (up == -1) ? -1 : nranks - 1 - up;
        *d1_0 = (down0 == -1) ? -1 : nranks - 1 - down0;
        *d1_1 = (down1 == -1) ? -1 : nranks - 1 - down1;
    }
    return scclSuccess;
}

} // namespace detect
} // namespace topology
} // namespace hardware
} // namespace sccl