Commit fa5f45f0 authored by Rick Ho

fix bugs to run megatron with gshard gate

parent 7f6463f0
@@ -4,7 +4,7 @@
 /*
  * note that due to limit of cuda atomic operator, capacity should be int32
  */
-std::vector<torch::Tensor> _limit_by_capacity(
+torch::Tensor _limit_by_capacity(
         torch::Tensor expert_count, torch::Tensor capacity,
         long n_expert, long n_worker) {
     CHECK_INPUT(expert_count);
@@ -16,16 +16,22 @@ std::vector<torch::Tensor> _limit_by_capacity(
             capacity.data_ptr<int>(),
             expert_count_ack.data_ptr<long>(),
             n_expert, n_worker, smgr);
-    return {expert_count_ack};
+    return expert_count_ack;
 }
 
-void _prune_gate_by_capacity(
+torch::Tensor _prune_gate_by_capacity(
         torch::Tensor gate_idx, torch::Tensor expert_count,
         long n_expert, long n_worker) {
     auto smgr = getCudaStreamManager(expert_count.device().index());
     auto batch_size = gate_idx.numel();
+    auto opt = torch::TensorOptions()
+        .dtype(gate_idx.dtype())
+        .device(gate_idx.device());
+    auto new_gate_idx = torch::empty(gate_idx.sizes(), opt);
     fmoe_cuda_prune_gate_by_capacity_impl(
         gate_idx.data_ptr<long>(),
+        new_gate_idx.data_ptr<long>(),
         expert_count.data_ptr<int>(),
         batch_size, n_expert, n_worker, smgr);
+    return new_gate_idx;
 }
@@ -31,24 +31,28 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
 }
 
 __global__
-void prune_gate_by_capacity_kernel(long* gate_idx, int* ec,
+void prune_gate_by_capacity_kernel(const long* gate_idx, long* new_gate_idx,
+        int* ec,
         const long batch_size, const long n_expert, const long n_worker) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i < batch_size) {
         int orig_cap = atomicSub(ec + gate_idx[i], 1);
         if (orig_cap <= 0) {
-            gate_idx[i] = -1;
+            new_gate_idx[i] = -1;
+        } else {
+            new_gate_idx[i] = gate_idx[i];
         }
     }
 }
 
-void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, int* ec,
+void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
+        int* ec,
         const long batch_size, const long n_expert, const long n_worker,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(batch_size, 1024));
     dim3 block_dim(1024);
     prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
-        gate_idx, ec, batch_size, n_expert, n_worker
+        gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
     );
     smgr->sync(1);
 }
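(Reading aid, not part of the commit.) The kernel now writes its result into a separate new_gate_idx buffer instead of overwriting gate_idx in place, so the caller keeps the original routing. Below is a sequential PyTorch sketch of the same per-token logic, with a made-up helper name; in the real kernel the order in which the atomicSub calls land decides which of the overflowing tokens are dropped.

import torch

def prune_gate_by_capacity_ref(gate_idx, expert_count):
    # Sequential reference for prune_gate_by_capacity_kernel: each token
    # decrements its expert's remaining quota; tokens that find the quota
    # already exhausted get -1 in a freshly allocated output, and the
    # original gate_idx is left untouched.
    new_gate_idx = torch.empty_like(gate_idx)
    ec = expert_count.clone()
    for i in range(gate_idx.numel()):
        e = int(gate_idx[i])
        orig_cap = int(ec[e])
        ec[e] -= 1                      # mirrors atomicSub(ec + gate_idx[i], 1)
        new_gate_idx[i] = -1 if orig_cap <= 0 else e
    return new_gate_idx

idx = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long)
cap = torch.tensor([2, 2], dtype=torch.int32)   # each expert may keep 2 tokens
print(prune_gate_by_capacity_ref(idx, cap))     # tensor([ 0,  0, -1,  1,  1])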
@@ -43,10 +43,10 @@ std::vector<torch::Tensor> _linear_backward(
 );
 
 // balancing
-std::vector<torch::Tensor> _limit_by_capacity(
+torch::Tensor _limit_by_capacity(
         torch::Tensor expert_count, torch::Tensor capacity,
         long n_expert, long n_experts);
-void _prune_gate_by_capacity(
+torch::Tensor _prune_gate_by_capacity(
         torch::Tensor gate_idx, torch::Tensor expert_count,
         long n_expert, long n_worker);
...
@@ -189,7 +189,7 @@ class MOEGather(Function):
                 global_output_buf,
                 local_expert_count,
                 global_expert_count,
-                local_batch_size,
+                pos.shape[0],
                 world_size,
             )
         else:
...
@@ -10,7 +10,8 @@ from .utils import limit_by_capacity
 
 class GShardGate(NaiveGate):
     def __init__(self, d_model, num_expert, world_size,
-            capacity=(1.2, 2.4), random_routing=True):
+            topk=2, capacity=(1.2, 2.4), random_routing=True):
+        assert topk == 2, 'topk should be 2 in gshard'
         super().__init__(d_model, num_expert, world_size, top_k=2)
         self.capacity = capacity
         self.random_routing = True
@@ -34,7 +35,8 @@ class GShardGate(NaiveGate):
         cap_rate = self.capacity[0 if self.training else 1]
         capacity = math.ceil(cap_rate * x.shape[0])
-        limit_by_capacity(topk_idx, self.num_expert, self.world_size, capacity)
+        _new_lec, _new_gec, topk_idx = limit_by_capacity(
+                topk_idx, self.num_expert, self.world_size, capacity)
 
         if self.random_routing:
             rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
...
@@ -14,8 +14,9 @@ class SwitchGate(NaiveGate):
     A switch gate implementation
     """
 
-    def __init__(self, d_model, num_expert, world_size,
+    def __init__(self, d_model, num_expert, world_size, topk=1,
             switch_eps=.1, capacity=(1.2, 2.4)):
+        assert topk == 1, 'topk should be 1 in switch'
         super().__init__(d_model, num_expert, world_size, top_k=1)
         self.switch_eps = switch_eps
         self.capacity = capacity
@@ -42,7 +43,8 @@ class SwitchGate(NaiveGate):
         cap_rate = self.capacity[0 if self.training else 1]
         capacity = math.ceil(cap_rate * inp.shape[0])
-        limit_by_capacity(top1_idx, self.num_expert, self.world_size, capacity)
+        _new_lec, _new_gec, top1_idx = limit_by_capacity(
+                top1_idx, self.num_expert, self.world_size, capacity)
 
         valid_idx = top1_idx[top1_idx > -1]
         fraction_expert = torch.scatter_add(
...
@@ -7,19 +7,20 @@ import fmoe_cuda as fmoe_native
 def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
-    capacity = torch.ones(num_expert, dtype=torch.int32,
-            device=topk_idx.device) * capacity
-    pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
-            require_pos=False)
-    new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
-            num_expert, world_size)
-    if world_size > 1:
-        new_lec, = fmoe_native.expert_exchange(new_gec, num_expert, world_size)
-    else:
-        new_lec = new_gec
-    fmoe_native.prune_gate_by_capacity(topk_idx,
-            new_lec.to(torch.int32), num_expert, world_size)
-    return new_lec, new_gec
+    with torch.no_grad():
+        capacity = torch.ones(num_expert, dtype=torch.int32,
+                device=topk_idx.device) * capacity
+        pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
+                require_pos=False)
+        new_gec = fmoe_native.limit_by_capacity(gec, capacity,
+                num_expert, world_size)
+        if world_size > 1:
+            new_lec, = fmoe_native.expert_exchange(new_gec, num_expert,
+                    world_size)
+        else:
+            new_lec = new_gec
+        topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx,
+                new_lec.to(torch.int32), num_expert, world_size)
+    return new_lec, new_gec, topk_idx
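(Reading aid, not part of the commit.) limit_by_capacity now returns three values, (new_lec, new_gec, topk_idx), and the gate call sites above unpack them and continue with the pruned index tensor, in which dropped tokens are marked -1 (SwitchGate then filters them with top1_idx > -1). Below is a self-contained sketch of the single-worker (world_size == 1) semantics in plain PyTorch, without the fmoe_cuda extension; the function name is invented for illustration.

import torch

def limit_by_capacity_ref(topk_idx, num_expert, capacity):
    # Emulates the helper above for world_size == 1 (no expert exchange):
    # clip the per-expert counts at `capacity`, then replace overflow
    # tokens by -1 in a new index tensor.
    lec = torch.bincount(topk_idx, minlength=num_expert)   # local expert counts
    new_lec = torch.clamp(lec, max=capacity)               # capacity-limited counts
    new_gec = new_lec                                       # single worker: global == local
    remaining = new_lec.clone()
    pruned = topk_idx.clone()
    for i, e in enumerate(topk_idx.tolist()):
        if remaining[e] > 0:
            remaining[e] -= 1
        else:
            pruned[i] = -1                                  # token dropped for capacity
    return new_lec, new_gec, pruned

idx = torch.tensor([0, 1, 0, 0, 1], dtype=torch.long)
new_lec, new_gec, pruned = limit_by_capacity_ref(idx, num_expert=2, capacity=2)
print(new_lec)   # tensor([2, 2])
print(pruned)    # tensor([ 0,  1,  0, -1,  1])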