"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "dbc1d505f018807089ea0da575f40ba22e8b4709"
Commit 8052b3c0 authored by Rick Ho

cpu swipe

parent 4a9ef7fd
#include "balancing.cuh" #include "balancing.cuh"
#include "global_exchange.h"
#include <torch/extension.h>
/*
@@ -35,3 +36,137 @@ torch::Tensor _prune_gate_by_capacity(
        batch_size, n_expert, n_worker, smgr);
    return new_gate_idx;
}
template<class T>
T* _cudamalloc(size_t sz) {
    T* dptr;
    cudaMalloc(&dptr, sz * sizeof(T));
    return dptr;
}
template<class T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
    cudaMemcpy(dptr, hptr, sz * sizeof(T), cudaMemcpyHostToDevice);
    return dptr;
}
template<class T>
T* _h2d(T* hptr, size_t sz) {
    T* dptr = _cudamalloc<T>(sz);
    return _h2d(hptr, dptr, sz);
}
template<class T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
    cudaMemcpy(hptr, dptr, sz * sizeof(T), cudaMemcpyDeviceToHost);
    return hptr;
}
template<class T>
T* _d2h(const T* dptr, size_t sz) {
    T* hptr = new T[sz];
    return _d2h(dptr, hptr, sz);
}
#ifdef FMOE_USE_NCCL
#include <nccl.h>
#define UPDATE_COUNTERS(__count__) { \
    if (i == rank) { \
        lec[j] += (__count__); \
    } \
    if (j == rank) { \
        gec[i] += (__count__); \
        cap -= (__count__); \
    } \
}
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity,
        long n_expert, long n_worker, long bias) {
    auto device_idx = gate_idx.device().index();
    auto smgr = getCudaStreamManager(device_idx);
    int rank;
    ncclCommUserRank(smgr->ncclcomm, &rank);
    cudaSetDevice(device_idx);
    auto cap = capacity.item<long>();
    long batch_size = gate_idx.size(0);
    auto gate_idx_cpu = gate_idx.cpu().contiguous();
    long* gidx = gate_idx_cpu.data_ptr<long>();
    /* Local count and exchange */
    long *lec = new long[n_worker];
    memset(lec, 0, n_worker * sizeof(long));
    for (long i = 0; i < batch_size; ++i) {
        ++lec[gidx[i] / n_expert];
    }
    long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
    fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
    long *gec = _d2h(d_gec, n_worker);
    /* Limit number of incoming samples */
    long *drop_count = new long[n_worker];
    memset(drop_count, 0, n_worker * sizeof(long));
    for (long i = 0; i < n_worker; ++i) {
        if (cap >= gec[i]) {
            drop_count[i] = 0;
            cap -= gec[i];
        } else {
            drop_count[i] = gec[i] - cap;
            gec[i] = cap;
            cap = 0;
        }
    }
    /* Send limit information back */
    _h2d(gec, d_gec, n_worker);
    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
    _d2h(d_lec, lec, n_worker);
    auto d_dropcount = _h2d(drop_count, n_worker);
    ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
            smgr->ncclcomm, smgr->stream());
    _d2h(d_dropcount, drop_count, n_worker);
    auto d_gcap = _cudamalloc<long>(n_worker);
    _h2d(&cap, d_gcap + rank, 1);
    ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
            smgr->ncclcomm, smgr->stream());
    auto gcap = _d2h(d_gcap, n_worker);
    /* Re-assign counts */
    for (long i = 0, j = 0; i < n_worker; ++i) {
        while (drop_count[i] > 0) {
            if (drop_count[i] > gcap[j]) {
                drop_count[i] -= gcap[j];
                UPDATE_COUNTERS(gcap[j]);
                ++j;
            } else {
                gcap[j] -= drop_count[i];
                UPDATE_COUNTERS(drop_count[i]);
                break;
            }
        }
    }
    for (long i = 0; i < batch_size; ++i) {
        auto widx = gidx[i] / n_expert;
        if (lec[widx] > 0) {
            --lec[widx];
        } else {
            gidx[i] = -1;
        }
    }
    for (long i = 0, k = 0; i < batch_size; ++i) {
        if (gidx[i] != -1) {
            continue;
        }
        for (; lec[k] == 0; ++k);
        gidx[i] = k * n_expert + bias;
        --lec[k];
    }
    /* Write the remaining capacity back so later top-k rounds see it */
    capacity.data_ptr<long>()[0] = cap;
    return {gate_idx_cpu, capacity};
}
#undef UPDATE_COUNTERS
#endif
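To make the UPDATE_COUNTERS walk above easier to follow, here is a small pure-Python sketch of the same greedy re-assignment (a hypothetical `reassign` helper, not part of this commit): every rank runs the identical loop over the globally summed drop counts and the all-gathered spare capacities, and only the `i == rank` / `j == rank` branches touch its own outgoing (`lec`), incoming (`gec`) and remaining-capacity (`cap`) counters.

```python
# Sketch only: mirrors the re-assignment loop in _swipe_once without NCCL.
# drop_count[i]: samples originating at worker i that were dropped somewhere;
# gcap[j]: spare capacity left on worker j (both identical on every rank).
# Assumes, like the C++ code, that total spare capacity covers all drops.
def reassign(drop_count, gcap, rank, lec, gec, cap):
    j = 0
    for i in range(len(drop_count)):
        while drop_count[i] > 0:
            take = min(drop_count[i], gcap[j])
            drop_count[i] -= take
            gcap[j] -= take
            if i == rank:      # my dropped samples are re-routed to worker j
                lec[j] += take
            if j == rank:      # I absorb `take` re-routed samples from worker i
                gec[i] += take
                cap -= take
            if gcap[j] == 0:   # worker j is full, move on to the next one
                j += 1
    return lec, gec, cap
```

Running it with, say, drop_count = [3, 0], gcap = [1, 4] and rank = 1 shows worker 1 absorbing two re-routed samples from worker 0 and ending with two spare slots.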
@@ -52,6 +52,9 @@ torch::Tensor _limit_by_capacity(
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker);
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity_tensor,
        long n_expert, long n_worker, long bias);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
@@ -59,6 +62,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("global_scatter", &_global_scatter, "FastMoE global scatter (CUDA)");
    m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
    m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
    m.def("swipe_once", &_swipe_once, "SWIPE balance strategy (CUDA)");
#endif
    m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
...
@@ -5,6 +5,33 @@
#ifdef FMOE_USE_NCCL
#include <nccl.h>
void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr) {
    NCCL_SAFE_CALL(ncclGroupStart());
    for (int i = 0; i < world_size; ++i) {
        NCCL_SAFE_CALL(ncclSend(
                local_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
        NCCL_SAFE_CALL(ncclRecv(
                global_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
    smgr->sync(1);
}
torch::Tensor _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers) {
@@ -31,7 +58,7 @@ torch::Tensor _global_scatter(
    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
    auto smgr = getCudaStreamManager(input_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_cuda_global_scatter", ([&] {
        fmoe_cuda_global_scatter_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
@@ -57,7 +84,7 @@ torch::Tensor _global_gather(
    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
    auto smgr = getCudaStreamManager(output_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
            "fmoe_cuda_global_gather", ([&] {
        fmoe_cuda_global_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
...
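Functionally, the pairwise ncclSend/ncclRecv loop in fmoe_cuda_expert_exchange_impl above is an all-to-all over per-expert counts. As a point of reference, the same exchange can be sketched with torch.distributed (a sketch assuming an initialised NCCL process group; the extension itself keeps the raw NCCL calls so it can run on its own stream):

```python
# Sketch, not the extension's code path: exchange expert counts with a
# collective instead of the explicit send/recv loop.
import torch
import torch.distributed as dist

def expert_exchange(local_expert_count: torch.Tensor,
                    n_expert: int, world_size: int) -> torch.Tensor:
    # local_expert_count is laid out as world_size chunks of n_expert counts,
    # chunk i describing how many samples we send to each expert on worker i.
    assert local_expert_count.numel() == n_expert * world_size
    global_expert_count = torch.empty_like(local_expert_count)
    # Chunk i of the input goes to rank i; chunk i of the output comes from
    # rank i -- exactly the pairing of the ncclSend/ncclRecv loop above.
    dist.all_to_all_single(global_expert_count, local_expert_count)
    return global_expert_count
```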
@@ -2,30 +2,11 @@
#ifdef FMOE_USE_NCCL
void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr);
template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
@@ -50,9 +31,9 @@ void fmoe_cuda_global_scatter_impl(
            int idx = i + j * n_expert;
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        local_input_buf + expert_ptr[idx] * in_feat,
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
@@ -106,9 +87,9 @@ void fmoe_cuda_global_gather_impl(
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        local_output_buf + expert_ptr[idx] * out_feat,
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
...
r"""
Balanced gate using SWIPE algorithm
"""
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .naive_gate import NaiveGate
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native
class SwipeGate(NaiveGate):
    requires_moe_group = True

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__(d_model, num_expert, world_size, top_k)

    def swipe_once(self, idx, capacity, bias):
        with torch.no_grad():
            idx_new, capacity = fmoe_native.swipe_once(idx, capacity,
                    self.num_expert, self.world_size, bias)
            idx_new = idx_new.to(idx.device)
        return idx_new, capacity

    def forward(self, inp):
        score = self.gate(inp)
        orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1)

        if not self.training:
            topk_val = F.softmax(orig_score, dim=-1)
            return orig_idx, topk_val

        capacity = torch.scalar_tensor(inp.shape[0] * self.top_k,
                dtype=torch.long)
        # Re-route each of the top-k choices in turn, threading the
        # remaining capacity through the native swipe_once call.
        topk_idxs = []
        for k in range(self.top_k):
            idx, capacity = self.swipe_once(orig_idx[:, k], capacity,
                    k % self.num_expert)
            topk_idxs.append(idx)
        topk_idx = torch.stack(topk_idxs).transpose(0, 1)
        idx_x = torch.arange(inp.shape[0],
                device=inp.device).repeat_interleave(self.top_k)
        topk_val = score[idx_x, topk_idx.reshape(-1)].view(-1, self.top_k)
        topk_val = F.softmax(topk_val, dim=-1)
        return topk_idx, topk_val
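Finally, a hedged usage sketch of the new gate. It is illustrative only: it assumes the file lands at fmoe/gates/swipe_gate.py, that the process group is initialised with one process per GPU (e.g. via torchrun), and that FMoETransformerMLP forwards world_size and gate to the FMoE base class; the exact constructor arguments and group setup may differ.

```python
# Illustrative wiring of SwipeGate into an FMoE MLP layer (sketch).
import torch
import torch.distributed as dist
from fmoe import FMoETransformerMLP
from fmoe.gates.swipe_gate import SwipeGate

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

moe = FMoETransformerMLP(
    num_expert=1,                      # one expert per worker, the usual SWIPE setting
    d_model=1024,
    d_hidden=4096,
    world_size=dist.get_world_size(),  # assumed to be forwarded to the FMoE base class
    gate=SwipeGate,                    # the gate added by this commit
).cuda()

x = torch.randn(8, 1024, device="cuda")   # (tokens, d_model)
y = moe(x)
print(y.shape)
```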