Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
898269fa
Unverified
Commit
898269fa
authored
Jul 11, 2025
by
Shangyan Zhou
Committed by
GitHub
Jul 11, 2025
Browse files
ibgda: support non-bond dual-port environments
ibgda: support non-bond dual-port environments via multi-port config
parents
b0f13ef7
1cd5eea6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
13 deletions
+25
-13
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+25
-13
No files found.
csrc/kernels/ibgda_device.cuh
View file @
898269fa
...
@@ -77,7 +77,7 @@ __device__ static __forceinline__
...
@@ -77,7 +77,7 @@ __device__ static __forceinline__
nvshmemi_ibgda_device_qp_t
*
ibgda_get_rc
(
int
pe
,
int
id
)
{
nvshmemi_ibgda_device_qp_t
*
ibgda_get_rc
(
int
pe
,
int
id
)
{
auto
state
=
ibgda_get_state
();
auto
state
=
ibgda_get_state
();
const
auto
num_rc_per_pe
=
ibgda_get_state
()
->
num_rc_per_pe
;
const
auto
num_rc_per_pe
=
ibgda_get_state
()
->
num_rc_per_pe
;
return
&
state
->
globalmem
.
rcs
[
pe
*
num_rc_per_pe
+
id
%
num_rc_per_pe
];
return
&
state
->
globalmem
.
rcs
[
pe
*
num_rc_per_pe
*
state
->
num_devices_initialized
+
id
%
(
num_rc_per_pe
*
state
->
num_devices_initialized
)
];
}
}
__device__
static
__forceinline__
__device__
static
__forceinline__
...
@@ -199,20 +199,22 @@ ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *v
...
@@ -199,20 +199,22 @@ ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *v
__device__
static
__forceinline__
__device__
static
__forceinline__
uint64_t
ibgda_get_lkey_and_rkey
(
uint64_t
laddr
,
__be32
*
lkey
,
uint64_t
ibgda_get_lkey_and_rkey
(
uint64_t
laddr
,
__be32
*
lkey
,
uint64_t
raddr
,
int
dst_pe
,
uint64_t
*
out_raddr
,
__be32
*
out_rkey
)
{
uint64_t
raddr
,
int
dst_pe
,
uint64_t
*
out_raddr
,
__be32
*
out_rkey
,
uint32_t
dev_idx
)
{
auto
state
=
ibgda_get_state
();
auto
state
=
ibgda_get_state
();
auto
heap_start
=
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
);
auto
heap_start
=
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
);
auto
log2_cumem_granularity
=
state
->
log2_cumem_granularity
;
auto
log2_cumem_granularity
=
state
->
log2_cumem_granularity
;
// Local key
// Local key
uint64_t
idx
=
(
laddr
-
heap_start
)
>>
log2_cumem_granularity
;
uint64_t
idx
=
(
(
laddr
-
heap_start
)
>>
log2_cumem_granularity
)
*
state
->
num_devices_initialized
+
dev_idx
;
auto
device_key
=
state
->
constmem
.
lkeys
[
idx
];
auto
device_key
=
state
->
constmem
.
lkeys
[
idx
];
auto
lchunk_size
=
device_key
.
next_addr
-
laddr
;
auto
lchunk_size
=
device_key
.
next_addr
-
laddr
;
*
lkey
=
device_key
.
key
;
*
lkey
=
device_key
.
key
;
// Remote key
// Remote key
uint64_t
roffset
=
raddr
-
heap_start
;
uint64_t
roffset
=
raddr
-
heap_start
;
idx
=
((
roffset
>>
log2_cumem_granularity
)
*
nvshmemi_device_state_d
.
npes
)
+
dst_pe
;
idx
=
((
roffset
>>
log2_cumem_granularity
)
*
nvshmemi_device_state_d
.
npes
)
*
state
->
num_devices_initialized
+
dst_pe
*
state
->
num_devices_initialized
+
dev_idx
;
if
(
idx
<
NVSHMEMI_IBGDA_MAX_CONST_RKEYS
)
{
if
(
idx
<
NVSHMEMI_IBGDA_MAX_CONST_RKEYS
)
{
device_key
=
state
->
constmem
.
rkeys
[
idx
];
device_key
=
state
->
constmem
.
rkeys
[
idx
];
}
else
{
}
else
{
...
@@ -227,12 +229,13 @@ uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
...
@@ -227,12 +229,13 @@ uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
}
}
__device__
static
__forceinline__
void
__device__
static
__forceinline__
void
ibgda_get_rkey
(
uint64_t
addr
,
int
dst_pe
,
uint64_t
*
out_raddr
,
__be32
*
out_rkey
)
{
ibgda_get_rkey
(
uint64_t
addr
,
int
dst_pe
,
uint64_t
*
out_raddr
,
__be32
*
out_rkey
,
uint32_t
dev_idx
)
{
auto
state
=
ibgda_get_state
();
auto
state
=
ibgda_get_state
();
auto
heap_start
=
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
);
auto
heap_start
=
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
);
uint64_t
roffset
=
addr
-
heap_start
;
uint64_t
roffset
=
addr
-
heap_start
;
uint64_t
idx
=
((
roffset
>>
state
->
log2_cumem_granularity
)
*
nvshmemi_device_state_d
.
npes
)
+
dst_pe
;
uint64_t
idx
=
((
roffset
>>
state
->
log2_cumem_granularity
)
*
nvshmemi_device_state_d
.
npes
*
state
->
num_devices_initialized
)
+
dst_pe
*
state
->
num_devices_initialized
+
dev_idx
;
nvshmemi_ibgda_device_key_t
device_key
;
nvshmemi_ibgda_device_key_t
device_key
;
if
(
idx
<
NVSHMEMI_IBGDA_MAX_CONST_RKEYS
)
if
(
idx
<
NVSHMEMI_IBGDA_MAX_CONST_RKEYS
)
device_key
=
state
->
constmem
.
rkeys
[
idx
];
device_key
=
state
->
constmem
.
rkeys
[
idx
];
...
@@ -261,10 +264,10 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t
...
@@ -261,10 +264,10 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t
// NOTES: the `p` operation will not cross multiple remote chunks
// NOTES: the `p` operation will not cross multiple remote chunks
__be32
rkey
;
__be32
rkey
;
uint64_t
raddr
;
uint64_t
raddr
;
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
dst_pe
,
&
raddr
,
&
rkey
);
// Write WQEs
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
dst_pe
,
&
raddr
,
&
rkey
,
qp
->
dev_idx
);
// Write WQEs
uint64_t
base_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
uint64_t
base_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
void
*
wqe_ptrs
;
void
*
wqe_ptrs
;
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
base_wqe_idx
);
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
base_wqe_idx
);
...
@@ -336,11 +339,21 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
...
@@ -336,11 +339,21 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
uint64_t
my_raddr
=
0
;
uint64_t
my_raddr
=
0
;
uint64_t
my_chunk_size
=
0
;
uint64_t
my_chunk_size
=
0
;
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
// Decide how many messages (theoretically 3 for maximum)
// Decide how many messages (theoretically 3 for maximum)
auto
remaining_bytes
=
bytes
;
auto
remaining_bytes
=
bytes
;
while
(
remaining_bytes
>
0
)
{
while
(
remaining_bytes
>
0
)
{
if
(
lane_id
==
num_wqes
)
if
(
lane_id
==
num_wqes
)
{
my_chunk_size
=
min
(
remaining_bytes
,
ibgda_get_lkey_and_rkey
(
my_laddr
=
req_lptr
,
&
my_lkey
,
req_rptr
,
dst_pe
,
&
my_raddr
,
&
my_rkey
));
my_chunk_size
=
min
(
remaining_bytes
,
ibgda_get_lkey_and_rkey
(
my_laddr
=
req_lptr
,
&
my_lkey
,
req_rptr
,
dst_pe
,
&
my_raddr
,
&
my_rkey
,
qp
->
dev_idx
));
}
// Move one more message
// Move one more message
auto
chunk_size
=
__shfl_sync
(
0xffffffff
,
my_chunk_size
,
static_cast
<
int
>
(
num_wqes
));
auto
chunk_size
=
__shfl_sync
(
0xffffffff
,
my_chunk_size
,
static_cast
<
int
>
(
num_wqes
));
...
@@ -352,7 +365,6 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
...
@@ -352,7 +365,6 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
EP_DEVICE_ASSERT
(
num_wqes
<=
32
);
EP_DEVICE_ASSERT
(
num_wqes
<=
32
);
// Process WQE
// Process WQE
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
uint64_t
base_wqe_idx
=
0
;
uint64_t
base_wqe_idx
=
0
;
if
(
lane_id
==
0
)
if
(
lane_id
==
0
)
base_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
num_wqes
);
base_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
num_wqes
);
...
@@ -419,7 +431,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
...
@@ -419,7 +431,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
__be32
rkey
;
__be32
rkey
;
uint64_t
raddr
;
uint64_t
raddr
;
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
);
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
,
qp
->
dev_idx
);
uint64_t
my_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
uint64_t
my_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
void
*
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
my_wqe_idx
);
void
*
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
my_wqe_idx
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment