Commit 24324cc5 authored by Shangyan Zhou's avatar Shangyan Zhou
Browse files

Fix combine tma mbarrier

parent ee5bd170
......@@ -1439,13 +1439,13 @@ combine(int4* combined_x, float* combined_topk_weights,
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerSenderWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + hidden_bytes);
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + num_bytes_per_token);
uint32_t tma_phase = 0;
if (lane_id == 0) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerSenderWarp);
EP_DEVICE_ASSERT(num_bytes_per_token + sizeof(uint64_t) <= kNumTMABytesPerSenderWarp);
}
__syncwarp();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment