Commit 1aced6d8 authored by Rick Ho

balancing cuda code

parent 6cb6bbe4
#include "balancing.cuh"
#include <torch/extension.h>
/*
 * Note: due to the limits of CUDA atomic operators, capacity must be int32.
 */
std::vector<torch::Tensor> _limit_by_capacity(
        torch::Tensor expert_count, torch::Tensor capacity,
        long n_expert, long n_worker) {
    auto expert_count_ack = torch::empty_like(expert_count);
    auto smgr = getCudaStreamManager(expert_count.device().index());
    fmoe_cuda_limit_by_capacity_impl(
            expert_count.data_ptr<long>(),
            capacity.data_ptr<int>(),
            expert_count_ack.data_ptr<long>(),
            n_expert, n_worker, smgr);
    return {expert_count_ack};
}
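As an aside, here is a hedged sketch of how this wrapper is meant to be driven from Python. The binding name limit_by_capacity is taken from its use in the gate code further below; the binding registration and the exact return convention are not shown in this commit, so treat those details as assumptions.

# Hypothetical usage sketch: assumes the extension is built and exposes the
# wrapper as fmoe_cuda.limit_by_capacity, as the GShardGate code below does.
import torch
import fmoe_cuda

n_expert, n_worker = 2, 2
# Per-(worker, expert) token counts, flattened as [w0e0, w0e1, w1e0, w1e1].
expert_count = torch.tensor([3, 1, 4, 2], dtype=torch.long, device="cuda")
# Capacity must be int32 because the kernel updates it with atomicSub,
# and it is modified in place by the call.
capacity = torch.tensor([5, 5], dtype=torch.int32, device="cuda")
new_count = fmoe_cuda.limit_by_capacity(expert_count, capacity,
                                        n_expert, n_worker)
# Expert 0 is over-subscribed (3 + 4 > 5), so its grants are clipped to a
# total of 5, e.g. [3, 1, 2, 2] if worker 0's atomic update lands first.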
#include "stream_manager.h"
#include "utils/fmoe_utils.h"
#include <cuda.h>
__global__
void limit_by_capacity_kernel(const long* ec, int* cap, long* eca,
        const long n_expert, const long n_worker) {
    int eid = blockIdx.y;
    int wid = blockIdx.x * blockDim.x + threadIdx.x;
    if (wid < n_worker) {
        int proposal = ec[wid * n_expert + eid];
        // atomicSub returns the capacity value *before* this worker's
        // subtraction, i.e. how much room was left when this worker arrived.
        int cap_left = atomicSub(cap + eid, proposal);
        if (cap_left >= proposal) {
            eca[wid * n_expert + eid] = proposal;
        } else if (cap_left >= 0) {
            eca[wid * n_expert + eid] = cap_left;
        } else {
            eca[wid * n_expert + eid] = 0;
        }
    }
}
void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
        long* eca, const long n_expert, const long n_worker,
        CudaStreamManager* smgr) {
    // One thread per (worker, expert) pair: grid.x covers workers in chunks
    // of 1024 threads, grid.y indexes experts.
    dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
    dim3 block_dim(1024);
    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
            ec, cap, eca, n_expert, n_worker);
    smgr->sync(1);
}
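For readers who prefer plain Python, the following sketch (not part of the commit) mirrors what the kernel computes: for each expert, workers draw their proposed counts from the remaining capacity. On the GPU the draw order is whatever order the atomicSub calls happen to land in, so the per-worker split can vary between runs, but each expert's total grant is always min(capacity, sum of proposals).

# Sequential reference of limit_by_capacity (illustrative only; the CUDA
# kernel does the same with one atomicSub per (worker, expert) pair).
def limit_by_capacity_ref(ec, cap, n_expert, n_worker):
    eca = [0] * (n_worker * n_expert)
    for eid in range(n_expert):
        left = cap[eid]
        for wid in range(n_worker):
            proposal = ec[wid * n_expert + eid]
            # Grant the full proposal if it fits, otherwise whatever is left.
            eca[wid * n_expert + eid] = min(proposal, max(left, 0))
            left -= proposal
    return eca

# Expert 0 has capacity 5 but workers 0 and 1 propose 3 and 4 tokens.
print(limit_by_capacity_ref([3, 1, 4, 2], [5, 5], n_expert=2, n_worker=2))
# -> [3, 1, 2, 2]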
@@ -10,7 +10,27 @@ import fmoe_cuda
 from .utils import get_torch_default_comm
 
 
-def moe_prepare_forward(gate, num_expert, world_size, comm=None):
+def count_by_gate(gate, num_expert, world_size, comm=None):
+    # TODO: support -1 in gate, which means ignore this input
+    with torch.no_grad():
+        _, pos = torch.sort(gate)
+        gate_idx, gate_count = torch.unique(gate, return_counts=True)
+        local_expert_count = torch.zeros(
+            num_expert * world_size, device=gate.device, dtype=torch.long
+        )
+        local_expert_count.index_put_((gate_idx.long(),), gate_count)
+        if world_size > 1:
+            (global_expert_count,) = fmoe_cuda.expert_exchange(
+                local_expert_count, num_expert, world_size
+            )
+        else:
+            global_expert_count = local_expert_count
+    return pos, local_expert_count, global_expert_count
+
+
+def prepare_forward(gate, num_expert, world_size, comm=None):
     r"""
     Prepare necessary information from gate output for MoE computation.
@@ -26,20 +46,9 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
             comm = get_torch_default_comm()
         fmoe_cuda.ensure_nccl(comm, gate)
 
+    pos, local_expert_count, global_expert_count = count_by_gate(gate,
+            num_expert, world_size)
     with torch.no_grad():
-        _, pos = torch.sort(gate)
-        gate_idx, gate_count = torch.unique(gate, return_counts=True)
-        local_expert_count = torch.zeros(
-            num_expert * world_size, device=gate.device, dtype=torch.long
-        )
-        local_expert_count.index_put_((gate_idx.long(),), gate_count)
-        if world_size > 1:
-            (global_expert_count,) = fmoe_cuda.expert_exchange(
-                local_expert_count, num_expert, world_size
-            )
-        else:
-            global_expert_count = local_expert_count
         fwd_expert_count = global_expert_count.view(world_size,
                 num_expert).sum(dim=0)
         fwd_batch_size = int(fwd_expert_count.sum().item())
......
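A small worked example of what the new count_by_gate returns (illustrative; it assumes the fmoe package with its CUDA extension is installed and a GPU is available):

import torch
from fmoe.functions import count_by_gate

# Four tokens routed to two experts on a single worker.
gate = torch.tensor([1, 0, 1, 1], device="cuda")
pos, lec, gec = count_by_gate(gate, num_expert=2, world_size=1)
# pos  sorts token indices by expert, e.g. tensor([1, 0, 2, 3])
# lec  counts tokens per (worker, expert): tensor([1, 3])
# gec  equals lec when world_size == 1
# prepare_forward then reduces gec over workers, i.e.
# gec.view(world_size, num_expert).sum(dim=0), to obtain fwd_expert_count
# and fwd_batch_size.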
@@ -4,6 +4,9 @@ Balanced gate with GShard's policy (Google, 2020)
+import math
 import torch
 import torch.nn.functional as F
 from .naive_gate import NaiveGate
+from fmoe.functions import count_by_gate
+import fmoe_cuda as fmoe_native
 
 
 class GShardGate(NaiveGate):
@@ -27,6 +30,18 @@ class GShardGate(NaiveGate):
         loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
         self.set_loss(loss)
 
-        # TODO: capacity limit
+        cap_rate = self.capacity[0 if self.training else 1]
+        capacity = torch.ones(self.num_expert, dtype=torch.int32,
+                device=gate.device)
+        capacity *= math.ceil(cap_rate * x.shape[0])
+        pos, lec, gec = count_by_gate(gate, self.num_expert, self.world_size)
+        new_gec = fmoe_native.limit_by_capacity(gec, capacity,
+                self.num_expert, self.world_size)
+        if self.world_size > 1:
+            new_lec = fmoe_native.expert_exchange(new_gec,
+                    self.num_expert, self.world_size)
+        else:
+            new_lec = new_gec
+        # TODO: re-assign gate
         return topk_idx, topk_val
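To put rough numbers on the capacity computation above (the rate and batch size are hypothetical, not from the commit): every expert receives the same int32 capacity, derived from the capacity rate for the current mode and the local batch size, and limit_by_capacity then clips the per-(worker, expert) counts so that no expert's total exceeds it.

import math

cap_rate = 1.2   # hypothetical training-time rate, i.e. self.capacity[0]
batch = 512      # hypothetical local batch size, i.e. x.shape[0]
capacity_per_expert = math.ceil(cap_rate * batch)   # 615 tokens per expert
# expert_exchange then sends the clipped global counts (new_gec) back to the
# owning workers as new_lec, so each worker learns how many of its tokens
# every expert actually accepted.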