Commit fa5f45f0 authored by Rick Ho's avatar Rick Ho
Browse files

fix bugs to run megatron with gshard gate

parent 7f6463f0
......@@ -4,7 +4,7 @@
/*
* note that due to limit of cuda atomic operator, capacity should be int32
*/
std::vector<torch::Tensor> _limit_by_capacity(
torch::Tensor _limit_by_capacity(
torch::Tensor expert_count, torch::Tensor capacity,
long n_expert, long n_worker) {
CHECK_INPUT(expert_count);
......@@ -16,16 +16,22 @@ std::vector<torch::Tensor> _limit_by_capacity(
capacity.data_ptr<int>(),
expert_count_ack.data_ptr<long>(),
n_expert, n_worker, smgr);
return {expert_count_ack};
return expert_count_ack;
}
void _prune_gate_by_capacity(
torch::Tensor _prune_gate_by_capacity(
torch::Tensor gate_idx, torch::Tensor expert_count,
long n_expert, long n_worker) {
auto smgr = getCudaStreamManager(expert_count.device().index());
auto batch_size = gate_idx.numel();
auto opt = torch::TensorOptions()
.dtype(gate_idx.dtype())
.device(gate_idx.device());
auto new_gate_idx = torch::empty(gate_idx.sizes(), opt);
fmoe_cuda_prune_gate_by_capacity_impl(
gate_idx.data_ptr<long>(),
new_gate_idx.data_ptr<long>(),
expert_count.data_ptr<int>(),
batch_size, n_expert, n_worker, smgr);
return new_gate_idx;
}
......@@ -31,24 +31,28 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
}
/*
 * One thread per gate entry (1-D launch, bounds-checked against batch_size).
 * Each thread atomically decrements its expert's remaining capacity in ec;
 * if the counter was already exhausted the sample is marked dropped (-1) in
 * new_gate_idx. gate_idx is read-only; all results go to new_gate_idx.
 */
__global__
void prune_gate_by_capacity_kernel(const long* gate_idx, long* new_gate_idx,
        int* ec,
        const long batch_size, const long n_expert, const long n_worker) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < batch_size) {
        // atomicSub returns the value *before* the decrement, so a
        // non-positive result means capacity was already used up.
        int orig_cap = atomicSub(ec + gate_idx[i], 1);
        if (orig_cap <= 0) {
            new_gate_idx[i] = -1;
        } else {
            new_gate_idx[i] = gate_idx[i];
        }
    }
}
/*
 * Host-side launcher for prune_gate_by_capacity_kernel.
 * Launch config: 1024 threads per block, ceil-div grid over batch_size,
 * on the manager's stream 0; synchronizes before returning.
 */
void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
        int* ec,
        const long batch_size, const long n_expert, const long n_worker,
        CudaStreamManager* smgr) {
    dim3 grid_dim(CEIL(batch_size, 1024));
    dim3 block_dim(1024);
    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
        gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
    );
    smgr->sync(1);
}
......@@ -43,10 +43,10 @@ std::vector<torch::Tensor> _linear_backward(
);
// balancing
std::vector<torch::Tensor> _limit_by_capacity(
torch::Tensor _limit_by_capacity(
torch::Tensor expert_count, torch::Tensor capacity,
long n_expert, long n_experts);
void _prune_gate_by_capacity(
torch::Tensor _prune_gate_by_capacity(
torch::Tensor gate_idx, torch::Tensor expert_count,
long n_expert, long n_worker);
......
......@@ -189,7 +189,7 @@ class MOEGather(Function):
global_output_buf,
local_expert_count,
global_expert_count,
local_batch_size,
pos.shape[0],
world_size,
)
else:
......
......@@ -10,7 +10,8 @@ from .utils import limit_by_capacity
# GShard-style gate: NaiveGate fixed to top-2 routing plus capacity limiting.
class GShardGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size,
# NOTE(review): the next line is pre-change diff residue (old parameter
# list); the post-change signature is the `topk=2, ...` line below it.
capacity=(1.2, 2.4), random_routing=True):
topk=2, capacity=(1.2, 2.4), random_routing=True):
# GShard is hard-wired to top-2 routing; topk exists only for API symmetry.
assert topk == 2, 'topk should be 2 in gshard'
super().__init__(d_model, num_expert, world_size, top_k=2)
self.capacity = capacity
# NOTE(review): hard-coded True ignores the `random_routing` parameter —
# presumably should be `self.random_routing = random_routing`; confirm.
self.random_routing = True
......@@ -34,7 +35,8 @@ class GShardGate(NaiveGate):
cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * x.shape[0])
limit_by_capacity(topk_idx, self.num_expert, self.world_size, capacity)
_new_lec, _new_gec, topk_idx = limit_by_capacity(
topk_idx, self.num_expert, self.world_size, capacity)
if self.random_routing:
rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
......
......@@ -14,8 +14,9 @@ class SwitchGate(NaiveGate):
A switch gate implementation
"""
# NOTE(review): the next line is pre-change diff residue (old signature);
# the post-change signature is the `topk=1, ...` line below it.
def __init__(self, d_model, num_expert, world_size,
def __init__(self, d_model, num_expert, world_size, topk=1,
switch_eps=.1, capacity=(1.2, 2.4)):
# Switch routing sends each token to exactly one expert.
assert topk == 1, 'topk should be 1 in switch'
super().__init__(d_model, num_expert, world_size, top_k=1)
# switch_eps: noise scale used by the switch gate; capacity: (train, eval)
# capacity multipliers, matching GShardGate's convention.
self.switch_eps = switch_eps
self.capacity = capacity
......@@ -42,7 +43,8 @@ class SwitchGate(NaiveGate):
cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * inp.shape[0])
limit_by_capacity(top1_idx, self.num_expert, self.world_size, capacity)
_new_lec, _new_gec, top1_idx = limit_by_capacity(
top1_idx, self.num_expert, self.world_size, capacity)
valid_idx = top1_idx[top1_idx > -1]
fraction_expert = torch.scatter_add(
......
......@@ -7,19 +7,20 @@ import fmoe_cuda as fmoe_native
def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
    """Cap how many samples each expert receives.

    Builds a per-expert capacity vector, counts assignments with
    count_by_gate, clips the global expert count to capacity via the native
    kernel, exchanges the clipped counts across workers when distributed,
    and prunes over-capacity entries from topk_idx (pruned slots become -1
    in the returned tensor; the input is not modified).

    Returns:
        (new_lec, new_gec, topk_idx): clipped local expert count, clipped
        global expert count, and the pruned gate-index tensor.
    """
    with torch.no_grad():
        capacity = torch.ones(num_expert, dtype=torch.int32,
                device=topk_idx.device) * capacity
        # Positions are not needed here, only the counts.
        pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
                require_pos=False)
        new_gec = fmoe_native.limit_by_capacity(gec, capacity,
                num_expert, world_size)
        if world_size > 1:
            # expert_exchange still returns a sequence, hence the unpack.
            new_lec, = fmoe_native.expert_exchange(new_gec, num_expert,
                    world_size)
        else:
            new_lec = new_gec
        topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx,
                new_lec.to(torch.int32), num_expert, world_size)
    return new_lec, new_gec, topk_idx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment