infiniccl_test.cpp 6.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
#include "infiniccl_test.hpp"

#include <chrono>
#include <cstring>
#include <iostream>
#include <numeric>
#include <pthread.h>
#include <vector>

// Wrappers around CHECK_API_OR (defined outside this file) that compare the
// status of an Infini API call against INFINI_STATUS_SUCCESS.
// TEST_INFINI supplies `return 1` as the fallback action, for functions
// returning int; TEST_INFINI_THREAD supplies `return nullptr`, for pthread
// entry points returning void *.
// NOTE(review): presumably CHECK_API_OR runs the fallback action when the
// status differs from the expected value - confirm against its definition.
#define TEST_INFINI(API__) CHECK_API_OR(API__, INFINI_STATUS_SUCCESS, return 1)
#define TEST_INFINI_THREAD(API__) CHECK_API_OR(API__, INFINI_STATUS_SUCCESS, return nullptr)

YdrMaster's avatar
YdrMaster committed
13
// Largest element count exercised by the test (8 Mi elements); it is the last
// entry of TEST_COUNTS and also sizes the host input/answer buffers.
constexpr size_t MAX_COUNT = size_t{8} << 20;
14
// const size_t MAX_COUNT = 512 * 1024; // for metax
15
16
17
18
19
20
21
22

// Element counts run for every dtype, from a small message up to MAX_COUNT.
constexpr size_t TEST_COUNTS[] = {128, 1024, 4 * 1024, MAX_COUNT};

23
// Element types covered by the test; each one is run against every entry of
// TEST_COUNTS.
constexpr infiniDtype_t TEST_DTYPES[] = {
    INFINI_DTYPE_F32,
    INFINI_DTYPE_F16,
    INFINI_DTYPE_BF16,
};
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

// Untimed AllReduce calls issued before the measurement starts.
constexpr size_t WARM_UPS{10};

// Timed AllReduce calls; the reported time is the average over these.
constexpr size_t ITERATIONS{100};

// Per-rank argument bundle handed to testAllReduceThread through pthread_create.
// It is filled by aggregate initialisation in testAllReduce, so the field
// order must stay in sync with that initialiser.
struct ThreadArgs {
    // Rank index of this worker (the loop index in testAllReduce).
    int rank;
    // Device id passed to infinirtSetDevice by the worker.
    int device_id;
    // Communicator used for this rank's infinicclAllReduce calls.
    infinicclComm_t comm;
    // Device type passed to infinirtSetDevice by the worker.
    infiniDevice_t device_type;
    // Element type of `data` / `ans` and of the reduction.
    infiniDtype_t dtype;
    // Number of elements reduced.
    size_t count;
    // Host input buffer, copied to the device before the first AllReduce.
    const void *data;
    // Host buffer holding the expected result of the first AllReduce.
    const void *ans;
    // Out: set to 1 when the worker starts and to 0 only after it completes.
    int *result;
    // Out: average milliseconds per timed AllReduce call.
    double *time;
};

// Fills the first `count` elements of `data`, interpreted as an array of
// `dtype`, with `val` converted to that element type.
// Supports F32, F16 and BF16; any other dtype aborts the process.
void setData(infiniDtype_t dtype, void *data, size_t count, float val) {
    if (dtype == INFINI_DTYPE_F32) {
        float *const first = static_cast<float *>(data);
        for (float *p = first; p != first + count; ++p) {
            *p = val;
        }
        return;
    }

    if (dtype == INFINI_DTYPE_F16) {
        fp16_t *const first = static_cast<fp16_t *>(data);
        for (fp16_t *p = first; p != first + count; ++p) {
            *p = utils::cast<fp16_t>(val);
        }
        return;
    }

    if (dtype == INFINI_DTYPE_BF16) {
        bf16_t *const first = static_cast<bf16_t *>(data);
        for (bf16_t *p = first; p != first + count; ++p) {
            *p = utils::cast<bf16_t>(val);
        }
        return;
    }

    // Unsupported element type.
    std::abort();
}

template <typename T>
int checkData(const T *actual_, const T *expected_, size_t count) {
    int failed = 0;
    for (size_t i = 0; i < count; i++) {
        if constexpr (std::is_same<T, fp16_t>::value) {
            float actual = utils::cast<float>(actual_[i]);
            float expected = utils::cast<float>(expected_[i]);
            if (std::abs(actual - expected) > 1e-4) {
                failed += 1;
            }
76
77
78
79
80
81
        } else if constexpr (std::is_same<T, bf16_t>::value) {
            float actual = utils::cast<float>(actual_[i]);
            float expected = utils::cast<float>(expected_[i]);
            if (std::abs(actual - expected) > 1e-4) {
                failed += 1;
            }
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
        } else {
            if (std::abs(actual_[i] - expected_[i]) > 1e-4) {
                failed += 1;
            }
        }
    }
    return failed;
}

int checkData(const void *actual, const void *expected, infiniDtype_t dtype, size_t count) {
    switch (dtype) {
    case INFINI_DTYPE_F32:
        return checkData((const float *)actual, (const float *)expected, count);
    case INFINI_DTYPE_F16:
        return checkData((const fp16_t *)actual, (const fp16_t *)expected, count);
97
98
    case INFINI_DTYPE_BF16:
        return checkData((const bf16_t *)actual, (const bf16_t *)expected, count);
99
100
101
102
103
104
105
106
107
108
    default:
        std::abort();
        return 1;
    }
}

void *testAllReduceThread(void *arg) {
    ThreadArgs *args = (ThreadArgs *)arg;
    *(args->result) = 1;
    TEST_INFINI_THREAD(infinirtSetDevice(args->device_type, args->device_id));
Pan Zezhong's avatar
Pan Zezhong committed
109
110
    infinirtStream_t stream;
    TEST_INFINI_THREAD(infinirtStreamCreate(&stream));
111
112
113
114
115
    void *output = std::malloc(args->count * infiniSizeOf(args->dtype));
    std::memset(output, 0, args->count * infiniSizeOf(args->dtype));
    void *buf;
    TEST_INFINI_THREAD(infinirtMalloc(&buf, args->count * infiniSizeOf(args->dtype)));
    TEST_INFINI_THREAD(infinirtMemcpy(buf, args->data, args->count * infiniSizeOf(args->dtype), INFINIRT_MEMCPY_H2D));
Pan Zezhong's avatar
Pan Zezhong committed
116
    TEST_INFINI_THREAD(infinicclAllReduce(buf, buf, args->count, args->dtype, INFINICCL_SUM, args->comm, stream));
117
118
119
120
121
122
    TEST_INFINI_THREAD(infinirtDeviceSynchronize());
    TEST_INFINI_THREAD(infinirtMemcpy(output, buf, args->count * infiniSizeOf(args->dtype), INFINIRT_MEMCPY_D2H));

    if (checkData(output, args->ans, args->dtype, args->count) != 0) {
        std::free(output);
        infinirtFree(buf);
Pan Zezhong's avatar
Pan Zezhong committed
123
        infinirtStreamDestroy(stream);
124
125
126
        return nullptr;
    }
    for (size_t i = 0; i < WARM_UPS; i++) {
Pan Zezhong's avatar
Pan Zezhong committed
127
        TEST_INFINI_THREAD(infinicclAllReduce(buf, buf, args->count, args->dtype, INFINICCL_SUM, args->comm, stream));
128
129
130
131
132
133
    }
    TEST_INFINI_THREAD(infinirtDeviceSynchronize());

    // measure time
    auto start = std::chrono::high_resolution_clock::now();
    for (size_t i = 0; i < ITERATIONS; i++) {
Pan Zezhong's avatar
Pan Zezhong committed
134
        TEST_INFINI_THREAD(infinicclAllReduce(buf, buf, args->count, args->dtype, INFINICCL_SUM, args->comm, stream));
135
136
137
138
139
140
141
142
143
144
    }
    TEST_INFINI_THREAD(infinirtDeviceSynchronize());
    auto end = std::chrono::high_resolution_clock::now();
    double elapsed_ms = std::chrono::duration<double, std::milli>(end - start).count();
    *args->time = elapsed_ms / ITERATIONS;

    *args->result = 0;

    std::free(output);
    infinirtFree(buf);
Pan Zezhong's avatar
Pan Zezhong committed
145
    infinirtStreamDestroy(stream);
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
    return nullptr;
}

// Runs the AllReduce test on `ndevice` devices of `device_type`, one worker
// thread per rank, for every combination of TEST_DTYPES and TEST_COUNTS.
//
// Every rank contributes a buffer filled with 1.0, so the expected result of
// the sum is `ndevice` in every element. Per-rank timings (or a failure note)
// are printed for each combination.
//
// Returns 0 when every combination passed on every rank, 1 on the first
// failure (including setup failures).
//
// NOTE(review): the communicators created by infinicclCommInitAll are not
// released in this function - confirm whether a matching destroy call is
// expected here.
int testAllReduce(infiniDevice_t device_type, int ndevice) {
    // Owns the host buffers so that every return path frees them.
    struct HostBuffers {
        void *data = nullptr;
        void *ans = nullptr;
        ~HostBuffers() {
            std::free(data); // free(nullptr) is a no-op
            std::free(ans);
        }
    };

    std::vector<ThreadArgs> thread_args(ndevice);
    std::vector<infinicclComm_t> comms(ndevice);
    std::vector<pthread_t> threads(ndevice);
    std::vector<int> device_ids(ndevice);
    std::vector<int> results(ndevice);
    std::vector<double> times(ndevice);

    HostBuffers host;
    host.data = std::malloc(MAX_COUNT * sizeof(float)); // Use float as max dtype size
    host.ans = std::malloc(MAX_COUNT * sizeof(float));
    if (host.data == nullptr || host.ans == nullptr) {
        std::cout << "Failed to allocate host buffers." << std::endl;
        return 1;
    }

    // Rank i runs on device i.
    std::iota(device_ids.begin(), device_ids.end(), 0);
    TEST_INFINI(infinicclCommInitAll(device_type, comms.data(), ndevice, device_ids.data()));

    for (infiniDtype_t dtype : TEST_DTYPES) {
        setData(dtype, host.data, MAX_COUNT, 1.0f);
        setData(dtype, host.ans, MAX_COUNT, 1.0f * ndevice);
        for (size_t count : TEST_COUNTS) {
            std::cout << "Testing AllReduce with " << count << " elements of " << infiniDtypeToString(dtype) << std::endl;

            // Mark every rank as failed up front so that a rank whose thread
            // never runs cannot inherit a passing result from the last round.
            results.assign(results.size(), 1);

            int started = 0;
            for (int rank = 0; rank < ndevice; rank++) {
                thread_args[rank] = {rank, device_ids[rank], comms[rank], device_type, dtype, count, host.data, host.ans, &results[rank], &times[rank]};
                if (pthread_create(&threads[rank], nullptr, testAllReduceThread, &thread_args[rank]) != 0) {
                    // NOTE(review): ranks that already started may block in
                    // the collective while a peer is missing.
                    std::cout << "Rank " << rank << ": failed to create thread." << std::endl;
                    break;
                }
                started++;
            }
            // Only join handles that pthread_create actually initialised.
            for (int rank = 0; rank < started; rank++) {
                pthread_join(threads[rank], nullptr);
            }

            int failed = std::accumulate(results.begin(), results.end(), 0);
            for (int rank = 0; rank < ndevice; rank++) {
                if (results[rank] != 0) {
                    std::cout << "Rank " << rank << ": incorrect results." << std::endl;
                } else {
                    std::cout << "Rank " << rank << ": " << times[rank] << " ms." << std::endl;
                }
            }

            if (failed > 0) {
                std::cout << "Failed with " << failed << " errors." << std::endl
                          << std::endl;
                return 1;
            }
            std::cout << std::endl;
        }
    }

    return 0;
}