#include <torch/extension.h>
#include "balancing.cuh"
#include "global_exchange.h"
#include <cstring>

/*
 * note that due to limit of cuda atomic operator, capacity should be int32
 */

/* Limit the per-expert sample counts by the given capacity and return the
 * acknowledged counts. */
torch::Tensor _limit_by_capacity(
        torch::Tensor expert_count, torch::Tensor capacity,
        long n_expert, long n_worker) {
    CHECK_INPUT(expert_count);
    CHECK_INPUT(capacity);
    auto expert_count_ack = torch::empty_like(expert_count);
    auto smgr = getCudaStreamManager(expert_count.device().index());
    fmoe_cuda_limit_by_capacity_impl(
            expert_count.data_ptr<long>(),
            capacity.data_ptr<int>(),
            expert_count_ack.data_ptr<long>(),
            n_expert, n_worker, smgr);
    return expert_count_ack;
}

/* Produce a new gate-index tensor in which samples beyond an expert's
 * capacity are pruned. */
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker) {
    auto smgr = getCudaStreamManager(expert_count.device().index());
    auto batch_size = gate_idx.numel();
    auto opt = torch::TensorOptions()
        .dtype(gate_idx.dtype())
        .device(gate_idx.device());
    auto new_gate_idx = torch::empty(gate_idx.sizes(), opt);
    fmoe_cuda_prune_gate_by_capacity_impl(
            gate_idx.data_ptr<long>(),
            new_gate_idx.data_ptr<long>(),
            expert_count.data_ptr<int>(),
            batch_size, n_expert, n_worker,
            smgr);
    return new_gate_idx;
}

/* Thin helpers around cudaMalloc / cudaMemcpy used by the NCCL path below.
 * Buffers returned by the allocating overloads are owned by the caller. */
template <typename T>
T* _cudamalloc(size_t sz) {
    T* dptr;
    cudaMalloc(&dptr, sz * sizeof(T));
    return dptr;
}

template <typename T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
    cudaMemcpy(dptr, hptr, sz * sizeof(T), cudaMemcpyHostToDevice);
    return dptr;
}

template <typename T>
T* _h2d(T* hptr, size_t sz) {
    T* dptr = _cudamalloc<T>(sz);
    return _h2d(hptr, dptr, sz);
}

template <typename T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
    cudaMemcpy(hptr, dptr, sz * sizeof(T), cudaMemcpyDeviceToHost);
    return hptr;
}

template <typename T>
T* _d2h(const T* dptr, size_t sz) {
    T* hptr = new T[sz];
    return _d2h(dptr, hptr, sz);
}
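/*
 * Illustrative sketch only (not part of the original interface): how the
 * helpers above compose to round-trip a host buffer through device memory.
 * The function name and usage below are hypothetical documentation.
 */
static inline void _example_roundtrip(long* host_in, size_t n) {
    long* d_buf = _h2d(host_in, n);    // cudaMalloc + host-to-device copy
    long* h_copy = _d2h(d_buf, n);     // new[] + device-to-host copy
    /* ... use h_copy on the host ... */
    delete [] h_copy;                  // caller owns both buffers
    cudaFree(d_buf);
}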
#ifdef FMOE_USE_NCCL
#include <nccl.h>

/* When `__count__` samples dropped by worker i are re-routed to worker j,
 * update the local send count (lec), the received count (gec), and the
 * remaining capacity on the ranks involved. */
#define UPDATE_COUNTERS(__count__) { \
    if (i == rank) { \
        lec[j] += (__count__); \
    } \
    if (j == rank) { \
        gec[i] += (__count__); \
        cap -= (__count__); \
    } \
}

/* One pass of the swipe balancing strategy: count local assignments, exchange
 * counts across workers, drop samples beyond each worker's capacity, then
 * re-assign the dropped samples to expert `bias` on workers that still have
 * spare capacity. */
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity,
        long n_expert, long n_worker, long bias) {
    auto device_idx = gate_idx.device().index();
    auto smgr = getCudaStreamManager(device_idx);
    int rank;
    ncclCommUserRank(smgr->ncclcomm, &rank);
    cudaSetDevice(device_idx);

    auto capacity_new = capacity.clone();
    auto cap = capacity_new.item<long>();
    // fprintf(stderr, "%d initial cap %ld ws %ld ne %ld\n", rank, cap, n_worker, n_expert);
    long batch_size = gate_idx.size(0);
    auto gate_idx_cpu = gate_idx.cpu();
    long* gidx = gate_idx_cpu.data_ptr<long>();

    /* Local count and exchange */
    long *lec = new long[n_worker];
    memset(lec, 0, n_worker * sizeof(long));
    for (long i = 0; i < batch_size; ++i) {
        ++lec[gidx[i] / n_expert];
    }
    long *d_lec = _h2d(lec, n_worker),
         *d_gec = _cudamalloc<long>(n_worker);
    fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
    long *gec = _d2h(d_gec, n_worker);
    // fprintf(stderr, "%d initial ec, lec %ld %ld, gec %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1]);

    /* Limit number of incoming samples */
    long *drop_count = new long[n_worker];
    memset(drop_count, 0, n_worker * sizeof(long));
    for (long i = 0; i < n_worker; ++i) {
        if (cap >= gec[i]) {
            drop_count[i] = 0;
            cap -= gec[i];
        } else {
            drop_count[i] = gec[i] - cap;
            gec[i] = cap;
            cap = 0;
        }
    }
    // fprintf(stderr, "%d before exchange cap %ld, drop count %ld %ld, lgec %ld %ld\n", rank, cap, drop_count[0], drop_count[1], gec[0], gec[1]);

    /* Send limit information back */
    _h2d(gec, d_gec, n_worker);
    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
    _d2h(d_lec, lec, n_worker);

    auto d_dropcount = _h2d(drop_count, n_worker);
    ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
            smgr->ncclcomm, smgr->stream());
    _d2h(d_dropcount, drop_count, n_worker);

    auto d_gcap = _cudamalloc<long>(n_worker);
    _h2d(&cap, d_gcap + rank, 1);
    ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
            smgr->ncclcomm, smgr->stream());
    auto gcap = _d2h(d_gcap, n_worker);
    cudaDeviceSynchronize();
    // fprintf(stderr, "%d exchange fin, drop count %ld %ld, nlec %ld %ld, gcap %ld %ld\n", rank, drop_count[0], drop_count[1], lec[0], lec[1], gcap[0], gcap[1]);

    /* Re-assign and update counters */
    for (long i = 0, j = 0; i < n_worker; ++i) {
        while (drop_count[i] > 0) {
            if (drop_count[i] > gcap[j]) {
                drop_count[i] -= gcap[j];
                UPDATE_COUNTERS(gcap[j]);
                ++j;
            } else {
                gcap[j] -= drop_count[i];
                UPDATE_COUNTERS(drop_count[i]);
                break;
            }
        }
    }
    // fprintf(stderr, "%d update done, lec %ld %ld, gec %ld %ld, gcap %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1], gcap[0], gcap[1]);

    /* Keep samples within the local send quota; mark the rest as dropped */
    for (long i = 0; i < batch_size; ++i) {
        auto widx = gidx[i] / n_expert;
        if (lec[widx] > 0) {
            --lec[widx];
        } else {
            gidx[i] = -1;
        }
    }
    /* Re-assign dropped samples to workers with remaining quota */
    for (long i = 0, k = 0; i < batch_size; ++i) {
        if (gidx[i] != -1) {
            continue;
        }
        for (; lec[k] == 0; ++k);
        --lec[gidx[i] = k * n_expert + bias];
        // fprintf(stderr, "%d: assign %ld to %ld\n", rank, i, k);
    }
    *capacity_new.data_ptr<long>() = cap;
    // fprintf(stderr, "%d all done\n", rank);

    delete [] drop_count;
    delete [] lec;
    delete [] gec;
    delete [] gcap;
    cudaFree(d_dropcount);
    cudaFree(d_lec);
    cudaFree(d_gec);
    cudaFree(d_gcap);
    return {gate_idx_cpu, capacity_new};
}

#undef UPDATE_COUNTERS
#endif
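/*
 * Minimal sketch of how these entry points might be exposed to Python via
 * pybind11 in a torch extension. FastMoE registers its bindings in a separate
 * translation unit, so this block is guarded by a hypothetical macro
 * (FMOE_STANDALONE_BALANCING_BINDINGS) and compiled out by default; the
 * exported names and docstrings below are assumptions, not the library's
 * actual interface.
 */
#ifdef FMOE_STANDALONE_BALANCING_BINDINGS
#include <pybind11/pybind11.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("limit_by_capacity", &_limit_by_capacity,
          "Limit per-expert sample counts by the given capacity");
    m.def("prune_gate_by_capacity", &_prune_gate_by_capacity,
          "Prune gate indices that exceed expert capacity");
#ifdef FMOE_USE_NCCL
    m.def("swipe_once", &_swipe_once,
          "One round of swipe load-balancing re-assignment");
#endif
}
#endif  // FMOE_STANDALONE_BALANCING_BINDINGS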