"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "4038542b157d8e7ffd5c6f47166d9544fea03800"
Unverified Commit 201279fa authored by Sangkug Lym, committed by GitHub
Browse files

Use separate streams for pushsend/recv kernels in UB p2p exchanges (#188)



* using different streams for pushsend and pushrecv
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* fix stream dependency
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* add wait from main_stream to memcpy stream
Signed-off-by: Sangkug Lym <slym@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 25bb8647
......@@ -332,9 +332,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
void *_ubuf_ptr;
torch::Tensor _ubuf;
std::vector<torch::Tensor> _ubufs;
at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true);
at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true);
at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _start_accum, _stop_accum;
cudaEvent_t _start_compute, _stop_compute, _stop_send, _stop_recv;
UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, bool aggregate2,
int num_max_streams) {
......@@ -385,10 +386,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);
cudaEventCreateWithFlags(&_start_accum, 0);
cudaEventCreateWithFlags(&_stop_accum, 0);
cudaEventCreateWithFlags(&_stop_send, 0);
cudaEventCreateWithFlags(&_stop_recv, 0);
}
/*
......@@ -430,7 +429,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
assert(pre_gelu_out.numel() == 0);
if (_aggregate2) {
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
......@@ -442,11 +442,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_comm);
(cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_comm);
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
(cudaStream_t)_stream_recv);
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0));
int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
......@@ -476,18 +477,21 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
if (i < num_steps - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, (cudaStream_t)_stream_comm);
next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, (cudaStream_t)_stream_comm);
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
prev_rank, (cudaStream_t)_stream_recv);
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_comm, 0));
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
}
at::cuda::setCurrentCUDAStream(stream_main);
......@@ -497,7 +501,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
} else {
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
for (int i = 0; i < _tp_size; i++) {
......@@ -524,18 +529,21 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
if (i < _tp_size - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, (cudaStream_t)_stream_comm);
_next_rank, (cudaStream_t)_stream_send);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, (cudaStream_t)_stream_comm);
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
_prev_rank, (cudaStream_t)_stream_recv);
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0));
CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_comm, 0));
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
}
at::cuda::setCurrentCUDAStream(stream_main);
......@@ -544,7 +552,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder {
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
}
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _stop_compute, 0));
return D;
} // split_overlap_ag
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment