Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
3571a927
Commit
3571a927
authored
Jul 09, 2025
by
liuhe
Browse files
add DeepEP_multi_port_nobond ibgda support
parent
c50f3d6f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
12 deletions
+31
-12
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+13
-12
third-party/nvshmem.patch
third-party/nvshmem.patch
+18
-0
No files found.
csrc/kernels/ibgda_device.cuh
View file @
3571a927
...
...
@@ -77,7 +77,7 @@ __device__ static __forceinline__
nvshmemi_ibgda_device_qp_t
*
ibgda_get_rc
(
int
pe
,
int
id
)
{
auto
state
=
ibgda_get_state
();
const
auto
num_rc_per_pe
=
ibgda_get_state
()
->
num_rc_per_pe
;
return
&
state
->
globalmem
.
rcs
[
pe
*
num_rc_per_pe
+
id
%
num_rc_per_pe
];
return
&
state
->
globalmem
.
rcs
[
pe
*
num_rc_per_pe
*
state
->
num_devices_initialized
+
id
%
(
num_rc_per_pe
*
state
->
num_devices_initialized
)
];
}
__device__
static
__forceinline__
...
...
@@ -199,20 +199,21 @@ ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *v
__device__
static
__forceinline__
uint64_t
ibgda_get_lkey_and_rkey
(
uint64_t
laddr
,
__be32
*
lkey
,
uint64_t
raddr
,
int
dst_pe
,
uint64_t
*
out_raddr
,
__be32
*
out_rkey
)
{
uint64_t
raddr
,
int
dst_pe
,
uint64_t
*
out_raddr
,
__be32
*
out_rkey
,
uint32_t
dev_idx
)
{
auto
state
=
ibgda_get_state
();
auto
heap_start
=
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
);
auto
log2_cumem_granularity
=
state
->
log2_cumem_granularity
;
// Local key
uint64_t
idx
=
(
laddr
-
heap_start
)
>>
log2_cumem_granularity
;
//printf("device_idx === %u",dev_idx);
uint64_t
idx
=
((
laddr
-
heap_start
)
>>
log2_cumem_granularity
)
*
state
->
num_devices_initialized
+
dev_idx
;
auto
device_key
=
state
->
constmem
.
lkeys
[
idx
];
auto
lchunk_size
=
device_key
.
next_addr
-
laddr
;
*
lkey
=
device_key
.
key
;
// Remote key
uint64_t
roffset
=
raddr
-
heap_start
;
idx
=
((
roffset
>>
log2_cumem_granularity
)
*
nvshmemi_device_state_d
.
npes
)
+
dst_pe
;
idx
=
((
roffset
>>
log2_cumem_granularity
)
*
nvshmemi_device_state_d
.
npes
)
*
state
->
num_devices_initialized
+
dev_idx
+
dst_pe
*
state
->
num_devices_initialized
;
if
(
idx
<
NVSHMEMI_IBGDA_MAX_CONST_RKEYS
)
{
device_key
=
state
->
constmem
.
rkeys
[
idx
];
}
else
{
...
...
@@ -227,12 +228,12 @@ uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey,
}
__device__
static
__forceinline__
void
ibgda_get_rkey
(
uint64_t
addr
,
int
dst_pe
,
uint64_t
*
out_raddr
,
__be32
*
out_rkey
)
{
ibgda_get_rkey
(
uint64_t
addr
,
int
dst_pe
,
uint64_t
*
out_raddr
,
__be32
*
out_rkey
,
uint32_t
dev_idx
)
{
auto
state
=
ibgda_get_state
();
auto
heap_start
=
reinterpret_cast
<
uint64_t
>
(
nvshmemi_device_state_d
.
heap_base
);
uint64_t
roffset
=
addr
-
heap_start
;
uint64_t
idx
=
((
roffset
>>
state
->
log2_cumem_granularity
)
*
nvshmemi_device_state_d
.
npes
)
+
dst_pe
;
uint64_t
idx
=
((
roffset
>>
state
->
log2_cumem_granularity
)
*
nvshmemi_device_state_d
.
npes
*
state
->
num_devices_initialized
)
+
dev_idx
+
dst_pe
*
state
->
num_devices_initialized
;
nvshmemi_ibgda_device_key_t
device_key
;
if
(
idx
<
NVSHMEMI_IBGDA_MAX_CONST_RKEYS
)
device_key
=
state
->
constmem
.
rkeys
[
idx
];
...
...
@@ -261,10 +262,9 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t
// NOTES: the `p` operation will not cross multiple remote chunks
__be32
rkey
;
uint64_t
raddr
;
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
dst_pe
,
&
raddr
,
&
rkey
);
// Write WQEs
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
dst_pe
,
&
raddr
,
&
rkey
,
qp
->
dev_idx
);
// Write WQEs
uint64_t
base_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
void
*
wqe_ptrs
;
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
base_wqe_idx
);
...
...
@@ -336,11 +336,13 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
uint64_t
my_raddr
=
0
;
uint64_t
my_chunk_size
=
0
;
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
// Decide how many messages (theoretically 3 for maximum)
auto
remaining_bytes
=
bytes
;
while
(
remaining_bytes
>
0
)
{
if
(
lane_id
==
num_wqes
)
my_chunk_size
=
min
(
remaining_bytes
,
ibgda_get_lkey_and_rkey
(
my_laddr
=
req_lptr
,
&
my_lkey
,
req_rptr
,
dst_pe
,
&
my_raddr
,
&
my_rkey
));
my_chunk_size
=
min
(
remaining_bytes
,
ibgda_get_lkey_and_rkey
(
my_laddr
=
req_lptr
,
&
my_lkey
,
req_rptr
,
dst_pe
,
&
my_raddr
,
&
my_rkey
,
qp
->
dev_idx
));
// Move one more message
auto
chunk_size
=
__shfl_sync
(
0xffffffff
,
my_chunk_size
,
static_cast
<
int
>
(
num_wqes
));
...
...
@@ -352,7 +354,6 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
EP_DEVICE_ASSERT
(
num_wqes
<=
32
);
// Process WQE
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
uint64_t
base_wqe_idx
=
0
;
if
(
lane_id
==
0
)
base_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
num_wqes
);
...
...
@@ -419,7 +420,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
__be32
rkey
;
uint64_t
raddr
;
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
);
ibgda_get_rkey
(
reinterpret_cast
<
uint64_t
>
(
rptr
),
pe
,
&
raddr
,
&
rkey
,
qp
->
dev_idx
);
uint64_t
my_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
1
);
void
*
wqe_ptrs
=
ibgda_get_wqe_ptr
(
qp
,
my_wqe_idx
);
...
...
third-party/nvshmem.patch
View file @
3571a927
...
...
@@ -470,5 +470,23 @@ index ef325cd..16ee09c 100644
int nvshmemt_ibgda_show_info(struct nvshmem_transport *transport, int style) {
NVSHMEMI_ERROR_PRINT("ibgda show info not implemented");
---
src/host/team/team_internal.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/host/team/team_internal.cpp b/src/host/team/team_internal.cpp
index 8b8a263..1be3dec 100644
--- a/src/host/team/team_internal.cpp
+++ b/src/host/team/team_internal.cpp
@@ -1415,7 +1415,7 @@
CUDA_RUNTIME_CHECK(
cudaMemcpy(device_team_ret_val, team_ret_val, sizeof(int), cudaMemcpyHostToDevice));
CUDA_RUNTIME_CHECK(cudaDeviceSynchronize());
- nvshmemi_call_rdxn_on_stream_kernel<int, RDXN_OPS_MAX>(
+ nvshmemi_reduce_on_stream<int, RDXN_OPS_MAX>(
parent_team->team_idx, device_team_ret_val_reduced, device_team_ret_val, 1,
nvshmemi_state->my_stream);
CUDA_RUNTIME_CHECK(cudaStreamSynchronize(nvshmemi_state->my_stream));
--
2.34.1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment