Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
0a47402f
Commit
0a47402f
authored
Jul 02, 2025
by
Chenggang Zhao
Browse files
Code cleanup
parent
b6516358
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
6 deletions
+8
-6
csrc/kernels/internode.cu
csrc/kernels/internode.cu
+4
-2
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+4
-4
No files found.
csrc/kernels/internode.cu
View file @
0a47402f
...
...
@@ -138,12 +138,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
ld_volatile_global
,
st_na_global
);
}
}
__syncthreads
();
// Wait previous operations to be finished
if
(
thread_id
<
kNumRDMARanks
and
thread_id
!=
rdma_rank
)
nvshmemi_ibgda_quiet
(
translate_dst_rdma_rank
<
kLowLatencyMode
>
(
thread_id
,
nvl_rank
),
0
);
__syncthreads
();
// Barrier
if
(
thread_id
==
0
)
nvshmem_sync_with_same_gpu_idx
<
kLowLatencyMode
>
(
rdma_team
);
__syncthreads
();
...
...
csrc/kernels/internode_ll.cu
View file @
0a47402f
...
...
@@ -499,9 +499,9 @@ combine(void* combined_x,
cg
::
this_grid
().
sync
();
// Reduce tokens
EP_DEVICE_ASSERT
(
num_topk
<=
32
and
hidden_bf16_int4
<=
1024
);
EP_DEVICE_ASSERT
(
num_topk
<=
32
);
EP_STATIC_ASSERT
(
kHidden
%
(
32
*
kNumElemsPerInt4
)
==
0
,
"Invalid vectorization"
);
for
(
int
k
=
thread_id
;
k
<
hidden_bf16_int4
;
k
+=
num_threads
)
{
for
(
int
hidden_idx
=
thread_id
;
hidden_idx
<
hidden_bf16_int4
;
hidden_idx
+=
num_threads
)
{
for
(
int
token_idx
=
sm_id
;
token_idx
<
num_combined_tokens
;
token_idx
+=
num_sms
)
{
// Read top-k indices and weights
int
reg_topk_idx
[
kNumMaxTopk
];
...
...
@@ -520,7 +520,7 @@ combine(void* combined_x,
auto
rdma_buffer_row
=
reinterpret_cast
<
const
uint8_t
*>
(
rdma_buffer_type
);
// Reduce
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
k
);
auto
x_vec
=
ld_nc_global
(
reinterpret_cast
<
const
int4
*>
(
rdma_buffer_row
)
+
hidden_idx
);
const
auto
x_bf16
=
reinterpret_cast
<
nv_bfloat16
*>
(
&
x_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
;
++
j
)
...
...
@@ -533,7 +533,7 @@ combine(void* combined_x,
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerInt4
;
++
j
)
combined_bf16
[
j
]
=
static_cast
<
nv_bfloat16
>
(
combined_values
[
j
]);
(
static_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
)[
k
]
=
combined_int4
;
(
static_cast
<
int4
*>
(
combined_x
)
+
token_idx
*
hidden_bf16_int4
)[
hidden_idx
]
=
combined_int4
;
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment