Commit 8a56481b authored by Rick Ho

swipe pass test

parent c5cbd64b
@@ -91,7 +91,6 @@ std::vector<torch::Tensor> _swipe_once(
     auto capacity_new = capacity.clone();
     auto cap = capacity_new.item<long>();
-    // fprintf(stderr, "%d initial cap %ld ws %ld ne %ld\n", rank, cap, n_worker, n_expert);
     long batch_size = gate_idx.size(0);
     auto gate_idx_cpu = gate_idx.cpu();
@@ -106,7 +105,6 @@ std::vector<torch::Tensor> _swipe_once(
     long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
     fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
     long *gec = _d2h(d_gec, n_worker);
-    // fprintf(stderr, "%d initial ec, lec %ld %ld, gec %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1]);

     /* Limit number of incoming samples */
     long *drop_count = new long[n_worker];
@@ -122,7 +120,6 @@ std::vector<torch::Tensor> _swipe_once(
         }
     }
-    // fprintf(stderr, "%d before exchange cap %ld, drop count %ld %ld, lgec %ld %ld\n", rank, cap, drop_count[0], drop_count[1], gec[0], gec[1]);

     /* Send limit information back */
     _h2d(gec, d_gec, n_worker);
     fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
@@ -138,9 +135,7 @@ std::vector<torch::Tensor> _swipe_once(
     ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
             smgr->ncclcomm, smgr->stream());
     auto gcap = _d2h(d_gcap, n_worker);
-    cudaDeviceSynchronize();
-    // fprintf(stderr, "%d exchange fin, drop count %ld %ld, nlec %ld %ld, gcap %ld %ld\n", rank, drop_count[0], drop_count[1], lec[0], lec[1], gcap[0], gcap[1]);

     /* Re-assign and update counters */
     for (long i = 0, j = 0; i < n_worker; ++i) {
         while (drop_count[i] > 0) {
@@ -155,7 +150,6 @@ std::vector<torch::Tensor> _swipe_once(
             }
         }
     }
-    // fprintf(stderr, "%d update done, lec %ld %ld, gec %ld %ld, gcap %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1], gcap[0], gcap[1]);
     for (long i = 0; i < batch_size; ++i) {
         auto widx = gidx[i] / n_expert;
         if (lec[widx] > 0) {
@@ -169,11 +163,10 @@ std::vector<torch::Tensor> _swipe_once(
             continue;
         }
         for (; lec[k] == 0; ++k);
-        --lec[gidx[i] = k * n_expert + bias];
-        // fprintf(stderr, "%d: assign %ld to %ld\n", rank, i, k);
+        --lec[k];
+        gidx[i] = k * n_expert + bias;
     }
     *capacity_new.data_ptr<long>() = cap;
-    // fprintf(stderr, "%d all done\n", rank);

     delete [] drop_count;
     delete [] lec;
...
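For readers following the logic these hunks touch: the final loop in _swipe_once walks the batch and, for every sample whose chosen worker has no spare capacity left, scans forward to the next worker that still has room (the `for (; lec[k] == 0; ++k)` scan), decrements that worker's counter, and rewrites the sample's expert index; the commit splits the old combined `--lec[gidx[i] = ...]` statement so that `lec[k]` rather than `lec[k * n_expert + bias]` is decremented. Below is a minimal single-process Python sketch of that re-assignment step, using plain lists in place of the CUDA/NCCL buffers. It is an illustration under those assumptions, not the library's code, and it assumes total spare capacity covers the batch.

    def reassign(gidx, lec, n_expert, bias=0):
        # gidx: per-sample global expert index (worker-major: worker * n_expert + local expert)
        # lec:  per-worker spare-capacity counters, consumed in place
        # bias: local expert offset used for moved samples (assumed 0 here)
        k = 0  # cursor over workers that may still have room
        for i, g in enumerate(gidx):
            widx = g // n_expert           # worker the gate originally chose
            if lec[widx] > 0:              # still has room: keep the original choice
                lec[widx] -= 1
                continue
            while lec[k] == 0:             # otherwise scan forward to a worker with room
                k += 1
            lec[k] -= 1                    # mirrors the corrected `--lec[k]` ...
            gidx[i] = k * n_expert + bias  # ... and `gidx[i] = k * n_expert + bias`
        return gidx

    # Example: 2 workers x 2 experts, 2 slots per worker; the overflow that
    # wanted worker 0 is pushed onto worker 1 (global expert index 2):
    print(reassign([0, 1, 0, 1], [2, 2], n_expert=2))   # -> [0, 1, 2, 2]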
@@ -42,7 +42,6 @@ def _test_swipe_gate(d_model, batch_size, n_expert, top_k):
     topk_idx, topk_val = gate(x)


-@pytest.mark.parametrize("d_model", [1024])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("n_expert", [1, 4])
 @pytest.mark.parametrize("world_size", [2, 4, 8])
@@ -65,16 +64,16 @@ def _test_swipe_once(batch_size, n_expert):
     world_size = dist.get_world_size()
     gate = SwipeGate(4, n_expert, dist.get_world_size()).cuda()
     idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
-    capacity = torch.scalar_tensor(batch_size, dtype=torch.long)
+    capacity = torch.scalar_tensor(batch_size * 2, dtype=torch.long)
     ensure_comm(idx, None)
-    sys.stderr.write('{} Before swipe gate {}, capacity {}\n'.format(rank, idx, capacity))
     new_idx, new_cap = gate.swipe_once(idx, capacity, 0)
-    sys.stderr.write('{} final gte {}, cap {}\n'.format(rank, new_idx, new_cap))
+    idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
+    new_idx, new_cap = gate.swipe_once(idx, new_cap, 0)

 if __name__ == '__main__':
     if len(sys.argv) >= 3:
         args = json.loads(sys.argv[2])
         locals()[sys.argv[1]](**args)
     else:
-        # test_swipe_gate(8, 4, 8, 4, 2)
-        test_swipe_once(8, 8, 4)
+        test_swipe_gate(8, 4, 8, 4, 2)
+        # test_swipe_once(8, 800, 4)
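As a side note on the test driver: the `__main__` block dispatches on a function name given as the first CLI argument and a JSON dict of its keyword arguments given as the second. A self-contained sketch of that idiom (the file and function names below are illustrative, not part of the repository):

    import json
    import sys

    def demo_test(batch_size, n_expert):
        # stand-in for a real test body such as _test_swipe_once
        print('batch_size =', batch_size, 'n_expert =', n_expert)

    if __name__ == '__main__':
        # e.g.  python demo_dispatch.py demo_test '{"batch_size": 16, "n_expert": 4}'
        if len(sys.argv) >= 3:
            kwargs = json.loads(sys.argv[2])
            locals()[sys.argv[1]](**kwargs)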