Unverified commit 25e1816e, authored by Yi Zhang and committed by GitHub
Browse files

fix custom allreduce performance/accuracy problem (#4477)

parent a53fe428
......@@ -182,8 +182,9 @@ __inline__ __device__ void block_barrier(
}
}
}
__syncthreads();
if constexpr (start || need_fence) {
__syncthreads();
}
}
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
......@@ -262,6 +263,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
// Store to the destination buffer.
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
}
block_barrier<false>(
params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size);
}
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
......@@ -437,24 +440,8 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);
/*
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
*/
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
blocks_per_grid += 1;
}
threads_per_block = total_threads / blocks_per_grid;
// NOTE: need to adjust here
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
size_t iter_factor = 1;
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
iter_factor += 1;
}
blocks_per_grid /= iter_factor;
}
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
params.elts_per_rank = params.elts_total / params.ranks_per_node;
params.rank_offset = params.local_rank * params.elts_per_rank;
params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
......
......@@ -39,7 +39,7 @@ limitations under the License.
namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 32;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment