Commit 226e0779 authored by Rick Ho's avatar Rick Ho
Browse files

bug fix of swipe

parent 4f9f77f8
...@@ -104,6 +104,7 @@ std::vector<torch::Tensor> _swipe_once( ...@@ -104,6 +104,7 @@ std::vector<torch::Tensor> _swipe_once(
} }
long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker); long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr); fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
smgr->syncTorch();
long *gec = _d2h(d_gec, n_worker); long *gec = _d2h(d_gec, n_worker);
/* Limit number of incoming samples */ /* Limit number of incoming samples */
...@@ -123,17 +124,20 @@ std::vector<torch::Tensor> _swipe_once( ...@@ -123,17 +124,20 @@ std::vector<torch::Tensor> _swipe_once(
/* Send limit information back */ /* Send limit information back */
_h2d(gec, d_gec, n_worker); _h2d(gec, d_gec, n_worker);
fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr); fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
smgr->syncTorch();
_d2h(d_lec, lec, n_worker); _d2h(d_lec, lec, n_worker);
auto d_dropcount = _h2d(drop_count, n_worker); auto d_dropcount = _h2d(drop_count, n_worker);
ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum, ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
smgr->ncclcomm, smgr->stream()); smgr->ncclcomm, smgr->torchStream());
smgr->syncTorch();
_d2h(d_dropcount, drop_count, n_worker); _d2h(d_dropcount, drop_count, n_worker);
auto d_gcap = _cudamalloc<long>(n_worker); auto d_gcap = _cudamalloc<long>(n_worker);
_h2d(&cap, d_gcap + rank, 1); _h2d(&cap, d_gcap + rank, 1);
ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64, ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
smgr->ncclcomm, smgr->stream()); smgr->ncclcomm, smgr->torchStream());
smgr->syncTorch();
auto gcap = _d2h(d_gcap, n_worker); auto gcap = _d2h(d_gcap, n_worker);
/* Re-assign and update counters */ /* Re-assign and update counters */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment