Unverified Commit eb69fad7 authored by Daniel Stokes's avatar Daniel Stokes Committed by GitHub
Browse files

Fix incorrect TP rank calculation when using data parallel (#2179)


Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
parent 93a67af8
......@@ -607,10 +607,10 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr
int comm_bytes_per_rank = comm_bytes / _tp_size;
// We use the reference to the overlap_gemm to get the stream to send and receive on, to ensure the kernels don't finish until the previous gemm is flushed
userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
send_stream);
userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
recv_stream);
userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank,
_ub_comm, send_stream);
userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank,
_ub_comm, recv_stream);
// We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf
for (auto stream : {send_stream, recv_stream}) {
......
......@@ -2542,25 +2542,27 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
int tp_size, int world_rank, communicator *comm, cudaStream_t stream) {
int rank_round_tp = (world_rank / tp_size) * tp_size;
for (int j = 1; j < tp_size; j++) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * tp_rank;
int recv_offset = dstoffset + bytes_per_slice * tp_rank;
userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm,
rank_round_tp + i, stream);
}
}
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
int tp_size, int world_rank, communicator *comm, cudaStream_t stream) {
int rank_round_tp = (world_rank / tp_size) * tp_size;
for (int j = tp_size - 1; j > 0; j--) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * i;
int recv_offset = dstoffset + bytes_per_slice * i;
userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm,
rank_round_tp + i, stream);
}
}
......
......@@ -306,10 +306,10 @@ void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cuda
// Send this rank's slice of a registered userbuffer to every other rank in its
// TP group. world_rank (global rank) is required to compute global peer ids
// when multiple TP groups exist (data parallelism).
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream);
// Receive every other TP-group rank's slice into this rank's registered
// userbuffer; counterpart to userbuffers_send_all.
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream);
#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment