Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
e2c57848
Commit
e2c57848
authored
Apr 21, 2025
by
Shangyan Zhou
Browse files
Revert `ibgda_device.cuh` and remove some comments.
parent
5ab80c28
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
365 deletions
+4
-365
csrc/kernels/ibgda_device.cuh
csrc/kernels/ibgda_device.cuh
+2
-349
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+2
-16
No files found.
csrc/kernels/ibgda_device.cuh
View file @
e2c57848
...
...
@@ -11,10 +11,6 @@
#include "exception.cuh"
#include "utils.cuh"
// #define NVSHMEM_TIMEOUT_DEVICE_POLLING
// #define IBGDA_POLL_TIMEOUT 4000000000LLU
// #define NVSHMEM_IBGDA_DEBUG
namespace
deep_ep
{
EP_STATIC_ASSERT
(
NVSHMEMI_IBGDA_MIN_QP_DEPTH
>=
64
,
"Invalid QP minimum depth"
);
...
...
@@ -246,353 +242,15 @@ ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey)
*
out_rkey
=
device_key
.
key
;
}
// Branch-prediction hints. Each wraps a condition in __builtin_expect so the
// compiler knows which outcome is expected; `!!(x)` first normalizes the
// condition to 0 or 1. Every macro below is only defined when no earlier
// header has already provided one with the same name.
#ifndef likely
#define likely(x) (__builtin_expect(!!(x), 1))
#endif
#ifndef unlikely
#define unlikely(x) (__builtin_expect(!!(x), 0))
#endif
// ACCESS_ONCE(x) re-views the lvalue x through a volatile pointer to its own
// type, so every textual use of the macro is its own volatile access.
#ifndef ACCESS_ONCE
#define ACCESS_ONCE(x) (*(volatile typeof(x) *)&(x))
#endif
/**
 * DO NOT use BSWAP(READ_ONCE(x)) as it could create a bug.
 * BSWAP is a pre-processor function, so its argument is pasted in several
 * times: the result would contain many READ_ONCE(x), each of them a separate
 * volatile access. Read into a local variable first, then swap the local.
 */
#ifndef READ_ONCE
#define READ_ONCE(x) ACCESS_ONCE(x)
#endif
// WRITE_ONCE(x, v) stores v into x through the same volatile view.
#ifndef WRITE_ONCE
#define WRITE_ONCE(x, v) (ACCESS_ONCE(x) = (v))
#endif
#ifdef NVSHMEM_IBGDA_DEBUG
/**
 * Error-CQE view used when NVSHMEM_IBGDA_DEBUG is defined.
 *
 * It names the hardware/vendor syndrome bytes individually so that
 * ibgda_poll_cq can print them. The struct is overlaid on a raw completion
 * queue entry through a pointer cast, so the order and size of the members
 * must not be changed.
 * NOTE(review): the members add up to 64 bytes assuming __be32 / __be16 are
 * 4 / 2 bytes wide - confirm against the mlx5 headers.
 */
struct mlx5_err_cqe_ex {
    uint8_t rsvd0[32];
    __be32 srqn;
    uint8_t rsvd1[16];
    uint8_t hw_err_synd;
    uint8_t hw_synd_type;
    uint8_t vendor_err_synd;
    uint8_t syndrome;
    __be32 s_wqe_opcode_qpn;
    __be16 wqe_counter;
    uint8_t signature;
    uint8_t op_own;
};

using ibgda_mlx5_err_cqe_t = struct mlx5_err_cqe_ex;
#else
// Without debug output the error-CQE type declared elsewhere is used as is.
using ibgda_mlx5_err_cqe_t = struct mlx5_err_cqe;
#endif
/**
 * Exchange the two bytes of a 16-bit value.
 *
 * Implemented with one PTX `prmt.b32` byte permute. With selector 0x4401 the
 * result takes source byte 1 as its byte 0 and source byte 0 as its byte 1;
 * the two upper result bytes come from the zeroed `ign` register.
 *
 * @param x Value whose bytes are exchanged.
 * @return x with its low and high byte swapped.
 */
__device__ static inline uint16_t BSWAP16(uint16_t x) {
    const uint32_t widened = (uint32_t)x;
    uint32_t permuted;

    asm volatile(
        "{\n\t"
        ".reg .b32 mask;\n\t"
        ".reg .b32 ign;\n\t"
        "mov.b32 mask, 0x4401;\n\t"
        "mov.b32 ign, 0x0;\n\t"
        "prmt.b32 %0, %1, ign, mask;\n\t"
        "}"
        : "=r"(permuted)
        : "r"(widened));

    // Only the low 16 bits carry the swapped value.
    return (uint16_t)permuted;
}
/**
 * Read one uint8_t with a single load.
 *
 * When NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET is defined the load is
 * issued as the PTX instruction `ld.relaxed.gpu.global.L1::no_allocate.b8`;
 * otherwise it falls back to READ_ONCE (a volatile access).
 *
 * DO NOT use BSWAP(ibgda_atomic_read(x)) as it could create a bug.
 * See the comment near READ_ONCE.
 *
 * @param ptr Address to read from.
 * @return The byte that was read.
 */
__device__ static inline uint8_t ibgda_atomic_read(uint8_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
    // The "h" constraint binds a 16-bit register, so the byte is loaded into
    // a uint16_t and narrowed afterwards.
    uint16_t loaded;
    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(loaded) : "l"(ptr));
    return (uint8_t)loaded;
#else
    return READ_ONCE(*ptr);
#endif
}
/**
 * Read one uint16_t with a single load.
 *
 * When NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET is defined the load is
 * issued as the PTX instruction `ld.relaxed.gpu.global.L1::no_allocate.b16`;
 * otherwise it falls back to READ_ONCE (a volatile access).
 *
 * @param ptr Address to read from.
 * @return The value that was read (no byte-order conversion is applied).
 */
__device__ static inline uint16_t ibgda_atomic_read(uint16_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
    uint16_t loaded;
    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(loaded) : "l"(ptr));
    return loaded;
#else
    return READ_ONCE(*ptr);
#endif
}
/**
 * Read one uint32_t with a single load.
 *
 * When NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET is defined the load is
 * issued as the PTX instruction `ld.relaxed.gpu.global.L1::no_allocate.b32`;
 * otherwise it falls back to READ_ONCE (a volatile access).
 *
 * @param ptr Address to read from.
 * @return The value that was read (no byte-order conversion is applied).
 */
__device__ static inline uint32_t ibgda_atomic_read(uint32_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
    uint32_t loaded;
    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(loaded) : "l"(ptr));
    return loaded;
#else
    return READ_ONCE(*ptr);
#endif
}
/**
 * Read one uint64_t with a single load.
 *
 * When NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET is defined the load is
 * issued as the PTX instruction `ld.relaxed.gpu.global.L1::no_allocate.b64`;
 * otherwise it falls back to READ_ONCE (a volatile access).
 *
 * @param ptr Address to read from.
 * @return The value that was read.
 */
__device__ static inline uint64_t ibgda_atomic_read(uint64_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
    uint64_t loaded;
    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(loaded) : "l"(ptr));
    return loaded;
#else
    return READ_ONCE(*ptr);
#endif
}
/**
 * Memory fence that prevents code reordering from both the compiler and the
 * GPU.
 *
 * The default implementation is __threadfence_block(). Defining
 * NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE switches to the PTX instruction
 * `fence.acq_rel.cta` with a "memory" clobber.
 */
__device__ static inline void IBGDA_MFENCE() {
#ifndef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE
    __threadfence_block();
#else
    asm volatile("fence.acq_rel.cta;" ::: "memory");
#endif /* NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE */
}
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
/**
 * Sample the PTX special register %globaltimer.
 *
 * Only compiled when NVSHMEM_TIMEOUT_DEVICE_POLLING is defined; the polling
 * code uses differences of two samples to detect a timeout.
 * NOTE(review): the unit is presumably nanoseconds - confirm in the PTX ISA.
 *
 * @return The raw 64-bit timer value.
 */
__device__ static inline uint64_t ibgda_query_globaltimer() {
    uint64_t ticks;
    asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(ticks)::"memory");
    return ticks;
}
#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
/**
 * Check whether polling a CQ has exceeded IBGDA_POLL_TIMEOUT.
 *
 * On timeout, *error is set to -ETIME and a diagnostic line with the current
 * CQ state is printed. When there is no timeout, *error is left untouched.
 *
 * @param cq    Completion queue being polled.
 * @param now   Current timer sample.
 * @param start Timer sample taken when the wait (re)started.
 * @param idx   Index the caller is waiting for (only used in the message).
 * @param error Out-parameter that receives -ETIME on timeout.
 * @return 0 while within the time budget, -1 once the budget is exceeded.
 */
__device__ static inline int ibgda_check_poll_timeout(nvshmemi_ibgda_device_cq_t *cq, uint64_t now,
                                                      uint64_t start, uint64_t idx, int *error) {
    // Common case: still inside the polling budget.
    if (likely(now - start <= IBGDA_POLL_TIMEOUT)) return 0;

    struct mlx5_cqe64 *entry = (struct mlx5_cqe64 *)cq->cqe;

    *error = -ETIME;

    // The opcode lives in the upper four bits of op_own.
    const uint8_t own_byte = ibgda_atomic_read(&entry->op_own);
    const uint8_t cqe_opcode = own_byte >> 4;

    // Read first, then swap: the two steps are kept separate on purpose.
    const uint16_t raw_counter = ibgda_atomic_read(&entry->wqe_counter);
    const uint16_t hw_counter = BSWAP16(raw_counter);

    printf(
        "[%d] ibgda_poll_cq timeout:\n"
        " cons_idx=%#lx, prod_idx=%#lx, cqn=%#x, qpn=%#x, opcode=%#x\n"
        " wqe_counter=%#x, resv_head=%#lx, ready_head=%#lx\n"
        " while waiting for idx=%#lx.\n",
        nvshmemi_device_state_d.mype, ibgda_atomic_read(cq->cons_idx),
        ibgda_atomic_read(cq->prod_idx), cq->cqn, cq->qpn, cqe_opcode, hw_counter,
        ibgda_atomic_read(cq->resv_head), ibgda_atomic_read(cq->ready_head), idx);

    return -1;
}
#endif
/**
 * Wait until the completion queue's consumer index has reached `idx`.
 *
 * The function spins on the single CQE at cq->cqe: it first waits for
 * cq->prod_idx to reach idx, then watches the CQE's wqe_counter until it
 * covers idx (or another thread has already advanced cq->cons_idx), and
 * finally publishes the new consumer index with atomicMax.
 *
 * @param cq    Completion queue to poll.
 * @param idx   Consumer index to wait for.
 * @param error Out-parameter; receives the CQE syndrome when an error CQE is
 *              seen, or the value set by ibgda_check_poll_timeout on timeout.
 *              It is not written on success.
 * @return 0 on success, -1 on an error CQE, or the non-zero status returned
 *         by ibgda_check_poll_timeout (NVSHMEM_TIMEOUT_DEVICE_POLLING only).
 */
__device__ static inline int ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx,
                                           int *error) {
    int status = 0;
    struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;

    const uint32_t ncqes = cq->ncqes;

    uint8_t opown;
    uint8_t opcode;
    // Initialized so that the "observe progress" comparison in the timeout
    // path never reads an indeterminate value on the first loop iteration.
    uint16_t wqe_counter = 0;
    uint16_t new_wqe_counter;

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
    uint64_t start = ibgda_query_globaltimer();
    uint64_t now;
#endif

    uint64_t cons_idx = ibgda_atomic_read(cq->cons_idx);
    uint64_t new_cons_idx;

    assert(likely(cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_DCI ||
                  cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC));

    // Someone already consumed past idx: nothing to wait for.
    if (unlikely(cons_idx >= idx)) goto out;

#ifdef NVSHMEM_IBGDA_DEBUG
    // We can skip opcode == MLX5_CQE_INVALID check because we have already
    // initialized the CQ buffer to 0xff. With the QP depth range we enforce,
    // cons_idx cannot progress unless wqe_counter read from the CQ buffer is
    // a valid value.
    do {
        opown = ibgda_atomic_read(&cqe64->op_own);
        opcode = opown >> 4;

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
        // TODO: Integrate timeout handler with the core NVSHMEM
        now = ibgda_query_globaltimer();
        status = ibgda_check_poll_timeout(cq, now, start, idx, error);
        if (status != 0) goto check_opcode;
#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */
    } while (unlikely(opcode == MLX5_CQE_INVALID));

    // Prevent reordering of the opcode wait above
    IBGDA_MFENCE();
#endif /* NVSHMEM_IBGDA_DEBUG */

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
    start = ibgda_query_globaltimer();
#endif

    // If idx is a lot greater than cons_idx, we might get incorrect result due
    // to wqe_counter wraparound. We need to check prod_idx to be sure that idx
    // has already been submitted.
    while (unlikely(ibgda_atomic_read(cq->prod_idx) < idx))
        ;
    IBGDA_MFENCE();

    do {
        // Read first, then byte-swap the local copy.
        new_wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter);
        new_wqe_counter = BSWAP16(new_wqe_counter);
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
        now = ibgda_query_globaltimer();
        status = ibgda_check_poll_timeout(cq, now, start, idx, error);
        if (status != 0) goto check_opcode;

        // Observe progress. Reset the timer.
        if (new_wqe_counter != wqe_counter) start = now;
#endif
        wqe_counter = new_wqe_counter;

        // Another thread may have updated cons_idx.
        cons_idx = ibgda_atomic_read(cq->cons_idx);
        if (likely(cons_idx >= idx)) goto out;
    }
    // NOTE: This while loop is part of do while above.
    // wqe_counter is the HW consumer index. However, we always maintain index
    // + 1 in SW. To be able to compare with idx, we need to use wqe_counter +
    // 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for
    // sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than
    // idx, and thus we need to wait. We don't need to wait when idx ==
    // wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case
    // wraparound.
    while (unlikely(((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes)));

    // new_cons_idx is uint64_t but wqe_counter is uint16_t. Thus, we get the
    // MSB from idx. We also need to take care of wraparound.
    ++wqe_counter;
    new_cons_idx =
        (idx & ~(0xffffULL) | wqe_counter) + (((uint16_t)idx > wqe_counter) ? 0x10000ULL : 0x0);
    // atomicMax keeps cons_idx monotonic when several threads publish.
    atomicMax((unsigned long long int *)cq->cons_idx, (unsigned long long int)new_cons_idx);

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
check_opcode:
#endif

    // NVSHMEM always treats CQE errors as fatal.
    // Even if this error doesn't belong to the CQE in cons_idx,
    // we will just report and terminate the process.
    opown = ibgda_atomic_read(&cqe64->op_own);
    opcode = opown >> 4;

    if (unlikely(opcode == MLX5_CQE_REQ_ERR)) {
        ibgda_mlx5_err_cqe_t *cqe_err = (ibgda_mlx5_err_cqe_t *)cqe64;

        *error = cqe_err->syndrome;
#ifdef NVSHMEM_IBGDA_DEBUG
        // These locals intentionally shadow the outer wqe_counter: they hold
        // the raw (big-endian) fields of the error CQE for printing only.
        __be16 wqe_counter = ibgda_atomic_read(&cqe_err->wqe_counter);
        __be32 s_wqe_opcode_qpn = ibgda_atomic_read(&cqe_err->s_wqe_opcode_qpn);
        printf(
            "[%d] got completion with err:\n"
            " syndrome=%#x, vendor_err_synd=%#x, hw_err_synd=%#x, hw_synd_type=%#x,\n"
            " wqe_counter=%#x, s_wqe_opcode_qpn=%#x,\n"
            " cqn=%#x, cons_idx=%#lx, prod_idx=%#lx, idx=%#lx\n",
            nvshmemi_device_state_d.mype, cqe_err->syndrome, cqe_err->vendor_err_synd,
            cqe_err->hw_err_synd, cqe_err->hw_synd_type, BSWAP16(wqe_counter),
            BSWAP32(s_wqe_opcode_qpn), cq->cqn, cons_idx, ibgda_atomic_read(cq->prod_idx), idx);
#endif /* NVSHMEM_IBGDA_DEBUG */
        status = -1;
    }

out:
    // Prevent reordering of this function and subsequent instructions
    IBGDA_MFENCE();

    return status;
}
/**
 * Wait until everything posted so far on `qp` has completed.
 *
 * The index to wait for is read from qp->tx_wq.prod_idx when the device state
 * has use_async_postsend set, and from qp->mvars.tx_wq.ready_head otherwise.
 * The send CQ is then polled up to that index. A polling failure trips a
 * device assert; with NVSHMEM_IBGDA_DEBUG the error code is printed first.
 *
 * @param qp Queue pair to drain.
 * @return The index that was waited for.
 */
__device__ static inline uint64_t ibgda_quiet(nvshmemi_ibgda_device_qp_t *qp) {
    nvshmemi_ibgda_device_state_t *state = ibgda_get_state();

    uint64_t prod_idx;
    if (state->use_async_postsend) {
        prod_idx = ibgda_atomic_read(qp->tx_wq.prod_idx);
    } else {
        prod_idx = ibgda_atomic_read(&qp->mvars.tx_wq.ready_head);
    }

    // ibgda_poll_cq works on a by-value copy of the CQ descriptor.
    nvshmemi_ibgda_device_cq_t cq_copy = *qp->tx_wq.cq;

    int cqe_error = 0;
    int status = ibgda_poll_cq(&cq_copy, prod_idx, &cqe_error);
    // TODO: Integrate the error handler with the core NVSHMEM
#ifdef NVSHMEM_IBGDA_DEBUG
    if (status) {
        printf("ibgda_poll_cq failed with error=%d.\n", cqe_error);
    }
#endif
    assert(likely(status == 0));

    return prod_idx;
}
/**
 * Block until the WQE slot for `wqe_idx` may be reused.
 *
 * Completion is awaited for the entry half a work-queue depth behind
 * `wqe_idx` (index wqe_idx - nwqes / 2). While wqe_idx is still smaller than
 * that distance nothing is polled. A memory fence is issued in either case.
 * A polling failure is printed and then trips a device assert.
 *
 * @param qp      Queue pair whose send queue is being filled.
 * @param wqe_idx Index of the slot the caller wants to write.
 */
__device__ static inline void ibgda_wait_for_slot_availability(nvshmemi_ibgda_device_qp_t *qp,
                                                               uint64_t wqe_idx) {
    uint16_t half_depth = qp->tx_wq.nwqes;
    half_depth = half_depth / 2;

    // We don't want wqe_idx - half_depth to wraparound.
    if (unlikely(wqe_idx < half_depth)) {
        IBGDA_MFENCE();
        return;
    }

    // ibgda_poll_cq works on a by-value copy of the CQ descriptor.
    nvshmemi_ibgda_device_cq_t cq_copy = *qp->tx_wq.cq;
    int cqe_error = 0;
    int status = ibgda_poll_cq(&cq_copy, wqe_idx - half_depth, &cqe_error);
    // TODO: Integrate the error handler with the core NVSHMEM
    if (status) {
        printf("ibgda_poll_cq failed with error=%d.\n", cqe_error);
    }
    assert(likely(status == 0));

    IBGDA_MFENCE();
}
/**
 * Reserve `num_wqes` consecutive send-WQE slots on `qp`.
 *
 * The reservation is a single atomicAdd on mvars.tx_wq.resv_head, performed
 * exactly once. When `nbi` is false the function additionally checks that the
 * distance between tx_wq.prod_idx and tx_wq.cons_idx does not exceed the
 * queue depth (printing the values and asserting otherwise) and then waits
 * until the last reserved slot is available.
 *
 * @tparam nbi     true: return immediately after reserving;
 *                 false: also wait for slot availability.
 * @param qp       Queue pair to reserve on.
 * @param num_wqes Number of slots to reserve.
 * @return Index of the first reserved slot (value of resv_head before the add).
 */
template <bool nbi = true>
__device__ static __forceinline__ uint64_t ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp,
                                                                   uint32_t num_wqes) {
    auto mvars = &qp->mvars;

    const uint64_t wqe_idx =
        atomicAdd(reinterpret_cast<unsigned long long *>(&mvars->tx_wq.resv_head),
                  static_cast<unsigned long long>(num_wqes));

    if (!nbi) {
        uint64_t prod_idx = mvars->tx_wq.prod_idx;
        uint64_t cons_idx = mvars->tx_wq.cons_idx;
        uint64_t delta = prod_idx - cons_idx;
        uint64_t cnt = qp->tx_wq.nwqes;
        if (delta > cnt) {
            // Dump the indices before the assertion below aborts the kernel.
            printf("prod_idx: %lu\tcons_idx: %lu\tcnt: %lu\tdelta: %lu\n", prod_idx, cons_idx, cnt,
                   delta);
            EP_DEVICE_ASSERT(delta <= cnt);
        }

        // If last slot is available, all prior slots are also available.
        ibgda_wait_for_slot_availability(qp, wqe_idx + num_wqes);
    }

    // Single exit: the former trailing second `return atomicAdd(...)` was
    // unreachable and has been dropped so the reservation cannot be repeated.
    return wqe_idx;
}
/**
 * Translate a WQE index into the address of its slot in the send queue.
 *
 * The index is reduced with the mask (nwqes - 1) and scaled by
 * 1 << MLX5_SEND_WQE_SHIFT bytes from the start of qp->tx_wq.wqe.
 * NOTE(review): the mask acts as a modulo only if nwqes is a power of two -
 * presumably guaranteed by the queue setup; confirm.
 *
 * @param qp      Queue pair that owns the send queue.
 * @param wqe_idx WQE index (16 bit).
 * @return Pointer to the start of the slot.
 */
__device__ static __forceinline__ void *ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t *qp,
                                                          uint16_t wqe_idx) {
    const uint16_t cnt = qp->tx_wq.nwqes;
    EP_DEVICE_ASSERT(cnt != 0);

    const uint16_t slot = wqe_idx & (cnt - 1);
    const uintptr_t ring_base = reinterpret_cast<uintptr_t>(qp->tx_wq.wqe);

    return reinterpret_cast<void *>(ring_base + (slot << MLX5_SEND_WQE_SHIFT));
}
...
...
@@ -667,7 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
st_na_relaxed
(
reinterpret_cast
<
int4
*>
(
data_seg_ptr
),
*
reinterpret_cast
<
const
int4
*>
(
&
data_seg
));
}
template
<
bool
nbi
=
true
>
__device__
static
__forceinline__
void
nvshmemi_ibgda_put_nbi_warp
(
uint64_t
req_rptr
,
uint64_t
req_lptr
,
size_t
bytes
,
int
dst_pe
,
int
qp_id
,
int
lane_id
,
int
message_idx
)
{
// Get lkey and rkey, store them into lanes
...
...
@@ -697,7 +354,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
auto
qp
=
ibgda_get_rc
(
dst_pe
,
qp_id
);
uint64_t
base_wqe_idx
=
0
;
if
(
lane_id
==
0
)
base_wqe_idx
=
ibgda_reserve_wqe_slots
<
nbi
>
(
qp
,
num_wqes
);
base_wqe_idx
=
ibgda_reserve_wqe_slots
(
qp
,
num_wqes
);
base_wqe_idx
=
__shfl_sync
(
0xffffffff
,
base_wqe_idx
,
0
);
if
(
lane_id
<
num_wqes
)
{
auto
wqe_ptr
=
ibgda_get_wqe_ptr
(
qp
,
base_wqe_idx
+
lane_id
);
...
...
@@ -710,10 +367,6 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
if
(
lane_id
==
0
)
ibgda_submit_requests
<
false
>
(
qp
,
base_wqe_idx
,
num_wqes
,
message_idx
);
__syncwarp
();
// if (!nbi) {
// ibgda_quiet(qp);
// }
}
__device__
static
__forceinline__
void
ibgda_write_amo_add_wqe
(
...
...
csrc/kernels/internode.cu
View file @
e2c57848
...
...
@@ -711,14 +711,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
dst_slot_idx
=
synced_last_issued_tail
%
num_max_rdma_chunked_recv_tokens
;
EP_DEVICE_ASSERT
(
dst_slot_idx
+
num_tokens_to_issue
<=
num_max_rdma_chunked_recv_tokens
);
// nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
// rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
// num_bytes_per_rdma_token * num_tokens_to_issue,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
const
size_t
num_bytes_per_msg
=
(
num_bytes_per_rdma_token
*
num_tokens_to_issue
)
*
sizeof
(
int8_t
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
dst_slot_idx
*
num_bytes_per_rdma_token
);
nvshmemi_ibgda_put_nbi_warp
<
false
>
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
3
);
nvshmem_fence
();
}
else
{
...
...
@@ -731,8 +727,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if
(
lane_id
==
dst_rdma_rank
)
{
last_issued_tail
+=
num_tokens_to_issue
;
num_tokens_to_send
-=
num_tokens_to_issue
;
// nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
if
(
dst_rdma_rank
!=
rdma_rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_tokens_to_issue
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
);
...
...
@@ -939,8 +933,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Update remote head
if
(
min_head
!=
std
::
numeric_limits
<
int
>::
max
()
and
min_head
>=
last_head
+
num_max_rdma_chunked_send_tokens
and
lane_id
<
kNumRDMARanks
)
{
// nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
// translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
if
(
lane_id
!=
rdma_rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_head
.
buffer
(
rdma_rank
),
min_head
-
last_head
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
lane_id
,
nvl_rank
),
channel_id
);
...
...
@@ -1578,14 +1570,10 @@ combine(int4* combined_x, float* combined_topk_weights,
if
(
sub_warp_id
==
kNumWarpsPerForwarder
-
1
)
{
if
(
dst_rdma_rank
!=
rdma_rank
)
{
auto
rdma_slot_idx
=
token_start_idx
%
num_max_rdma_chunked_recv_tokens
;
// nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
// rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
// num_chunked_tokens * num_bytes_per_rdma_token,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
const
size_t
num_bytes_per_msg
=
(
num_chunked_tokens
*
num_bytes_per_rdma_token
)
*
sizeof
(
int8_t
);
const
auto
dst_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
recv_buffer
(
rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
const
auto
src_ptr
=
reinterpret_cast
<
uint64_t
>
(
rdma_channel_data
.
send_buffer
(
dst_rdma_rank
)
+
rdma_slot_idx
*
num_bytes_per_rdma_token
);
nvshmemi_ibgda_put_nbi_warp
<
false
>
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
nvshmemi_ibgda_put_nbi_warp
(
dst_ptr
,
src_ptr
,
num_bytes_per_msg
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
,
lane_id
,
3
);
nvshmem_fence
();
}
else
{
...
...
@@ -1595,8 +1583,6 @@ combine(int4* combined_x, float* combined_topk_weights,
// Write new RDMA tail
__syncwarp
();
if
(
lane_id
==
0
)
{
// nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
if
(
dst_rdma_rank
!=
rdma_rank
)
{
nvshmemi_ibgda_amo_nonfetch_add
(
rdma_channel_tail
.
buffer
(
rdma_rank
),
num_chunked_tokens
,
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
dst_rdma_rank
,
nvl_rank
),
channel_id
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment