Unverified commit 0970c4c8, authored by Shangyan Zhou and committed by GitHub
Browse files

Fix tma mbarrier (#399)

* Fix mbarrier

* Remove redundant store
parent 174c209f
...@@ -444,7 +444,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -444,7 +444,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// TMA stuffs // TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp; auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + hidden_bytes); auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + num_bytes_per_token);
uint32_t tma_phase = 0; uint32_t tma_phase = 0;
if ((warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kNVLReceivers) and lane_id == 0) { if ((warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kNVLReceivers) and lane_id == 0) {
mbarrier_init(tma_mbarrier, 1); mbarrier_init(tma_mbarrier, 1);
...@@ -943,16 +943,17 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv ...@@ -943,16 +943,17 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
} }
__syncwarp(); __syncwarp();
mbarrier_wait(tma_mbarrier, tma_phase); mbarrier_wait(tma_mbarrier, tma_phase);
if (lane_id == 0) if (lane_id == 0) {
tma_store_1d(tma_buffer, recv_x + recv_token_idx * hidden_int4, hidden_bytes, false); tma_store_1d(tma_buffer, recv_x + recv_token_idx * hidden_int4, hidden_bytes, false);
if (scale_aligned)
tma_store_1d(tma_buffer + hidden_bytes, recv_x_scales + recv_token_idx * num_scales, scale_bytes, false);
}
__syncwarp(); __syncwarp();
shifted += hidden_bytes; shifted += hidden_bytes;
// Copy scales // Copy scales
// TODO: make it as templated // TODO: make it as templated
if (scale_aligned) { if (not scale_aligned) {
tma_store_1d(tma_buffer + hidden_bytes, recv_x_scales + recv_token_idx * num_scales, scale_bytes, false);
} else {
UNROLLED_WARP_COPY(1, lane_id, num_scales, UNROLLED_WARP_COPY(1, lane_id, num_scales,
recv_x_scales + recv_token_idx * num_scales, recv_x_scales + recv_token_idx * num_scales,
reinterpret_cast<float*>(shifted), reinterpret_cast<float*>(shifted),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment