comm_manager.h 699 Bytes
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
#ifndef COMM_MANAGER_H
#define COMM_MANAGER_H

Rick Ho's avatar
Rick Ho committed
4
5
6
7
8
9
10
11
// NCCL_SAFE_CALL(call): evaluates the NCCL call expression exactly once.
// If the result is anything other than ncclSuccess, it prints the call
// site, the numeric result code and NCCL's description to stderr, then
// terminates the process with exit(-1).
// The do { ... } while (0) wrapper makes `NCCL_SAFE_CALL(f(...));` a single
// statement, so it is safe inside an unbraced if/else.
// NOTE(review): in an MPI job, exit() on one rank may leave peer ranks
// blocked in collectives; consider MPI_Abort instead.
#define NCCL_SAFE_CALL(call) do { \
	auto nccl_safe_call_res_ = (call); \
	if (nccl_safe_call_res_ != ncclSuccess) { \
		std::fprintf(stderr, "NCCL Error at %s:%d value %d (%s)\n", __FILE__, __LINE__, \
				static_cast<int>(nccl_safe_call_res_), ncclGetErrorString(nccl_safe_call_res_)); \
		std::exit(-1); \
	} \
} while (0)

Rick Ho's avatar
Rick Ho committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <cstdio>
#include <cstdlib>

#include <mpi.h>
#include "nccl.h"

struct CommManager {
	int rank, size;
	ncclComm_t ncclcomm;

	CommManager() {
		MPI_Comm_rank(MPI_COMM_WORLD, &rank);
		MPI_Comm_size(MPI_COMM_WORLD, &size);

		ncclUniqueId uid;
		if (rank == 0) {
			ncclGetUniqueId(&uid);
		}
		MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
Rick Ho's avatar
Rick Ho committed
28
		NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
Rick Ho's avatar
Rick Ho committed
29
30
31
32
33
34
	}
};

// Accessor for a CommManager; it is only declared here and defined in
// another translation unit.
// NOTE(review): presumably returns a process-wide, lazily constructed
// instance. If so, the first call is collective, because the constructor is.
// Confirm against the definition.
CommManager* getCommManager();

#endif  // COMM_MANAGER_H