Unverified commit eb69fad7, authored by Daniel Stokes, committed by GitHub

Fix incorrect TP rank calculation when using data parallel (#2179)


Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
parent 93a67af8
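
Why the fix matters: userbuffers_send_all and userbuffers_recv_all previously passed the tensor-parallel peer index (0 .. tp_size-1) straight through as the destination rank, which is only correct when the whole job is a single TP group. With data parallelism the job contains several TP groups, and a peer's global rank is its TP index offset by the first rank of the caller's group, i.e. rank_round_tp + i with rank_round_tp = (world_rank / tp_size) * tp_size. A minimal standalone sketch of that mapping (tp_peer_to_world_rank is an invented name; it assumes the contiguous TP-group rank layout the fix implies):

#include <cstdio>

// First global rank of the TP group containing world_rank, assuming TP groups
// occupy contiguous blocks of tp_size consecutive global ranks.
int tp_peer_to_world_rank(int world_rank, int tp_size, int tp_peer) {
  int rank_round_tp = (world_rank / tp_size) * tp_size;  // same expression as the fix
  return rank_round_tp + tp_peer;
}

int main() {
  // 8 GPUs with tp_size = 4 form two data-parallel replicas: ranks 0-3 and 4-7.
  // Before the fix, world rank 5 targeting TP peer 2 would address global rank 2,
  // a GPU in the other replica; the correct target is 4 + 2 = 6.
  printf("%d\n", tp_peer_to_world_rank(5, 4, 2));  // prints 6
  return 0;
}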
@@ -607,10 +607,10 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr
   int comm_bytes_per_rank = comm_bytes / _tp_size;
   // We use the reference to the overlap_gemm to get the stream to send and receive on to ensure the kernels don't finish until the previous gemm is flushed
-  userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
-                       send_stream);
-  userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
-                       recv_stream);
+  userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank,
+                       _ub_comm, send_stream);
+  userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank,
+                       _ub_comm, recv_stream);
   // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf
   for (auto stream : {send_stream, recv_stream}) {
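The "sync with the internal comm stream" comment above refers to the standard CUDA pattern of recording an event on each producer stream and making the consumer stream wait on it, so later work on the comm stream cannot start before the sends and receives finish. A self-contained sketch of that pattern (comm_stream and the event are illustrative stand-ins, not the library's actual members):

#include <cuda_runtime.h>
#include <initializer_list>

int main() {
  cudaStream_t send_stream, recv_stream, comm_stream;
  cudaStreamCreate(&send_stream);
  cudaStreamCreate(&recv_stream);
  cudaStreamCreate(&comm_stream);

  // ... enqueue communication kernels on send_stream and recv_stream here ...

  // Make comm_stream wait for both producer streams; anything enqueued on
  // comm_stream afterwards (e.g. a destructor's sync before freeing the
  // buffer) cannot begin until the sends and receives have completed.
  cudaEvent_t done;
  cudaEventCreateWithFlags(&done, cudaEventDisableTiming);
  for (auto stream : {send_stream, recv_stream}) {
    cudaEventRecord(done, stream);
    cudaStreamWaitEvent(comm_stream, done, 0);
  }

  cudaStreamSynchronize(comm_stream);
  cudaEventDestroy(done);
  cudaStreamDestroy(send_stream);
  cudaStreamDestroy(recv_stream);
  cudaStreamDestroy(comm_stream);
  return 0;
}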
@@ -2542,25 +2542,27 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
 void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                           const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
-                          int tp_size, communicator *comm, cudaStream_t stream) {
+                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream) {
+  int rank_round_tp = (world_rank / tp_size) * tp_size;
   for (int j = 1; j < tp_size; j++) {
     int i = (tp_rank + j) % tp_size;
     int send_offset = srcoffset + bytes_per_slice * tp_rank;
     int recv_offset = dstoffset + bytes_per_slice * tp_rank;
-    userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
-                     stream);
+    userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm,
+                     rank_round_tp + i, stream);
   }
 }
 
 void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                           const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
-                          int tp_size, communicator *comm, cudaStream_t stream) {
+                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream) {
+  int rank_round_tp = (world_rank / tp_size) * tp_size;
   for (int j = tp_size - 1; j > 0; j--) {
     int i = (tp_rank + j) % tp_size;
     int send_offset = srcoffset + bytes_per_slice * i;
     int recv_offset = dstoffset + bytes_per_slice * i;
-    userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
-                     stream);
+    userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm,
+                     rank_round_tp + i, stream);
   }
 }
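
With the fix applied, a dry run of the two loops above shows which global ranks a process targets on each iteration. This standalone sketch (print_schedules is an invented helper; it assumes tp_rank == world_rank % tp_size, which holds under the contiguous layout) mirrors the loop bounds and index math exactly:

#include <cstdio>

void print_schedules(int world_rank, int tp_size) {
  int tp_rank = world_rank % tp_size;
  int rank_round_tp = (world_rank / tp_size) * tp_size;
  printf("send order:");
  for (int j = 1; j < tp_size; j++)  // same loop as userbuffers_send_all
    printf(" %d", rank_round_tp + (tp_rank + j) % tp_size);
  printf("\nrecv order:");
  for (int j = tp_size - 1; j > 0; j--)  // same loop as userbuffers_recv_all
    printf(" %d", rank_round_tp + (tp_rank + j) % tp_size);
  printf("\n");
}

int main() {
  // World rank 5 in an 8-GPU job with tp_size = 4 (second replica, ranks 4-7):
  print_schedules(5, 4);  // send order: 6 7 4
                          // recv order: 4 7 6
  return 0;
}

Staggering the start peer by tp_rank spreads traffic so all ranks are not targeting the same destination at once, and the reversed loop in recv_all matches each receive to the corresponding send order.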
@@ -306,10 +306,10 @@ void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cuda
 void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                           const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
-                          int tp_size, communicator *comm, cudaStream_t stream);
+                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream);
 void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                           const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
-                          int tp_size, communicator *comm, cudaStream_t stream);
+                          int tp_size, int world_rank, communicator *comm, cudaStream_t stream);
 #endif  // TRANSFORMER_ENGINE_USERBUFFERS_H_