Commit 1cd5eea6, authored by root, committed by liuhe
Browse files

Increase MAX_NUM_HCAS from 16 to 32 to support more NICs in NVSHMEM

fix format
parent 3571a927
...@@ -199,13 +199,12 @@ ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *v ...@@ -199,13 +199,12 @@ ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *v
__device__ static __forceinline__ __device__ static __forceinline__
uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey, uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey ,uint32_t dev_idx) { uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey, uint32_t dev_idx) {
auto state = ibgda_get_state(); auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base); auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
auto log2_cumem_granularity = state->log2_cumem_granularity; auto log2_cumem_granularity = state->log2_cumem_granularity;
// Local key // Local key
//printf("device_idx === %u",dev_idx);
uint64_t idx = ((laddr - heap_start) >> log2_cumem_granularity) * state->num_devices_initialized + dev_idx; uint64_t idx = ((laddr - heap_start) >> log2_cumem_granularity) * state->num_devices_initialized + dev_idx;
auto device_key = state->constmem.lkeys[idx]; auto device_key = state->constmem.lkeys[idx];
auto lchunk_size = device_key.next_addr - laddr; auto lchunk_size = device_key.next_addr - laddr;
...@@ -213,7 +212,9 @@ uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey, ...@@ -213,7 +212,9 @@ uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
// Remote key // Remote key
uint64_t roffset = raddr - heap_start; uint64_t roffset = raddr - heap_start;
idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes)*state->num_devices_initialized + dev_idx + dst_pe * state->num_devices_initialized;
idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) * state->num_devices_initialized
+ dst_pe * state->num_devices_initialized + dev_idx;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) { if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) {
device_key = state->constmem.rkeys[idx]; device_key = state->constmem.rkeys[idx];
} else { } else {
...@@ -228,12 +229,13 @@ uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey, ...@@ -228,12 +229,13 @@ uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
} }
__device__ static __forceinline__ void __device__ static __forceinline__ void
ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey,uint32_t dev_idx) { ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey, uint32_t dev_idx) {
auto state = ibgda_get_state(); auto state = ibgda_get_state();
auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base); auto heap_start = reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base);
uint64_t roffset = addr - heap_start; uint64_t roffset = addr - heap_start;
uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes * state->num_devices_initialized) + dev_idx + dst_pe * state->num_devices_initialized ; uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes * state->num_devices_initialized)
+ dst_pe * state->num_devices_initialized + dev_idx;
nvshmemi_ibgda_device_key_t device_key; nvshmemi_ibgda_device_key_t device_key;
if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS)
device_key = state->constmem.rkeys[idx]; device_key = state->constmem.rkeys[idx];
...@@ -263,7 +265,8 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t ...@@ -263,7 +265,8 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t
__be32 rkey; __be32 rkey;
uint64_t raddr; uint64_t raddr;
auto qp = ibgda_get_rc(dst_pe, qp_id); auto qp = ibgda_get_rc(dst_pe, qp_id);
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey,qp->dev_idx); ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey, qp->dev_idx);
// Write WQEs // Write WQEs
uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs; void *wqe_ptrs;
...@@ -341,8 +344,16 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, ...@@ -341,8 +344,16 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
// Decide how many messages (theoretically 3 for maximum) // Decide how many messages (theoretically 3 for maximum)
auto remaining_bytes = bytes; auto remaining_bytes = bytes;
while (remaining_bytes > 0) { while (remaining_bytes > 0) {
if (lane_id == num_wqes) if (lane_id == num_wqes) {
my_chunk_size = min(remaining_bytes, ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey , qp->dev_idx)); my_chunk_size = min(remaining_bytes,
ibgda_get_lkey_and_rkey(my_laddr = req_lptr,
&my_lkey,
req_rptr,
dst_pe,
&my_raddr,
&my_rkey,
qp->dev_idx));
}
// Move one more message // Move one more message
auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes)); auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast<int>(num_wqes));
...@@ -420,7 +431,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons ...@@ -420,7 +431,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
__be32 rkey; __be32 rkey;
uint64_t raddr; uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey,qp->dev_idx); ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey, qp->dev_idx);
uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx); void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
......
...@@ -470,23 +470,5 @@ index ef325cd..16ee09c 100644 ...@@ -470,23 +470,5 @@ index ef325cd..16ee09c 100644
int nvshmemt_ibgda_show_info(struct nvshmem_transport *transport, int style) { int nvshmemt_ibgda_show_info(struct nvshmem_transport *transport, int style) {
NVSHMEMI_ERROR_PRINT("ibgda show info not implemented"); NVSHMEMI_ERROR_PRINT("ibgda show info not implemented");
---
src/host/team/team_internal.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/host/team/team_internal.cpp b/src/host/team/team_internal.cpp
index 8b8a263..1be3dec 100644
--- a/src/host/team/team_internal.cpp
+++ b/src/host/team/team_internal.cpp
@@ -1415,7 +1415,7 @@
CUDA_RUNTIME_CHECK(
cudaMemcpy(device_team_ret_val, team_ret_val, sizeof(int), cudaMemcpyHostToDevice));
CUDA_RUNTIME_CHECK(cudaDeviceSynchronize());
- nvshmemi_call_rdxn_on_stream_kernel<int, RDXN_OPS_MAX>(
+ nvshmemi_reduce_on_stream<int, RDXN_OPS_MAX>(
parent_team->team_idx, device_team_ret_val_reduced, device_team_ret_val, 1,
nvshmemi_state->my_stream);
CUDA_RUNTIME_CHECK(cudaStreamSynchronize(nvshmemi_state->my_stream));
-- --
2.34.1 2.34.1
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment