Unverified Commit 898269fa authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

ibgda: support non-bond dual-port environments

ibgda: support non-bond dual-port environments via multi-port config
parents b0f13ef7 1cd5eea6
......@@ -77,7 +77,7 @@ __device__ static __forceinline__
nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) {
auto state = ibgda_get_state();
const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe;
return &state->globalmem.rcs[pe * num_rc_per_pe + id % num_rc_per_pe];
return &state->globalmem.rcs[pe * num_rc_per_pe * state->num_devices_initialized + id % (num_rc_per_pe * state->num_devices_initialized)];
}
__device__ static __forceinline__
......@@ -199,20 +199,22 @@ ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *v
__device__ static __forceinline__
uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey, uint32_t dev_idx) {
auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
auto log2_cumem_granularity = state->log2_cumem_granularity;
// Local key
uint64_t idx = (laddr - heap_start) >> log2_cumem_granularity;
uint64_t idx = ((laddr - heap_start) >> log2_cumem_granularity) * state->num_devices_initialized + dev_idx;
auto device_key = state->constmem.lkeys[idx];
auto lchunk_size = device_key.next_addr - laddr;
*lkey = device_key.key;
// Remote key
uint64_t roffset = raddr - heap_start;
idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) * state->num_devices_initialized
+ dst_pe * state->num_devices_initialized + dev_idx;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) {
device_key = state->constmem.rkeys[idx];
} else {
......@@ -227,12 +229,13 @@ uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
}
__device__ static __forceinline__ void
ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) {
ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey, uint32_t dev_idx) {
auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
uint64_t roffset = addr - heap_start;
uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe;
uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes * state->num_devices_initialized)
+ dst_pe * state->num_devices_initialized + dev_idx;
nvshmemi_ibgda_device_key_t device_key;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS)
device_key = state->constmem.rkeys[idx];
......@@ -261,10 +264,10 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t
// NOTES: the `p` operation will not cross multiple remote chunks
__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey);
// Write WQEs
auto qp = ibgda_get_rc(dst_pe, qp_id);
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey, qp->dev_idx);
// Write WQEs
uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs;
wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx);
......@@ -336,11 +339,21 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
uint64_t my_raddr = 0;
uint64_t my_chunk_size = 0;
auto qp = ibgda_get_rc(dst_pe, qp_id);
// Decide how many messages (theoretically 3 for maximum)
auto remaining_bytes = bytes;
while (remaining_bytes > 0) {
if (lane_id == num_wqes)
my_chunk_size = min(remaining_bytes, ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey));
if (lane_id == num_wqes) {
my_chunk_size = min(remaining_bytes,
ibgda_get_lkey_and_rkey(my_laddr = req_lptr,
&my_lkey,
req_rptr,
dst_pe,
&my_raddr,
&my_rkey,
qp->dev_idx));
}
// Move one more message
auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
......@@ -352,7 +365,6 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
EP_DEVICE_ASSERT(num_wqes <= 32);
// Process WQE
auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t base_wqe_idx = 0;
if (lane_id == 0)
base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
......@@ -419,7 +431,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey, qp->dev_idx);
uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment