Commit c5cbd64b authored by Rick Ho's avatar Rick Ho

swipe test not passing

parent fe5a9cda
 #include <cstdio>
 #include "balancing.cuh"
 #include "global_exchange.h"
 #include <torch/extension.h>
@@ -88,7 +89,9 @@ std::vector<torch::Tensor> _swipe_once(
     ncclCommUserRank(smgr->ncclcomm, &rank);
     cudaSetDevice(device_idx);
-    auto cap = capacity.item<long>();
+    auto capacity_new = capacity.clone();
+    auto cap = capacity_new.item<long>();
+    // fprintf(stderr, "%d initial cap %ld ws %ld ne %ld\n", rank, cap, n_worker, n_expert);
     long batch_size = gate_idx.size(0);
     auto gate_idx_cpu = gate_idx.cpu();
@@ -98,16 +101,17 @@ std::vector<torch::Tensor> _swipe_once(
     long *lec = new long[n_worker];
     memset(lec, 0, n_worker * sizeof(long));
     for (long i = 0; i < batch_size; ++i) {
-        ++lec[gidx[i] % n_expert];
+        ++lec[gidx[i] / n_expert];
     }
     long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
     fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
-    long *gec = _d2h(d_gec, n_expert);
+    long *gec = _d2h(d_gec, n_worker);
+    // fprintf(stderr, "%d initial ec, lec %ld %ld, gec %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1]);
     /* Limit number of incoming samples */
     long *drop_count = new long[n_worker];
     memset(drop_count, 0, n_worker * sizeof(long));
-    for (long i = 0; i < n_expert; ++i) {
+    for (long i = 0; i < n_worker; ++i) {
         if (cap >= gec[i]) {
             drop_count[i] = 0;
             cap -= gec[i];
@@ -118,10 +122,11 @@ std::vector<torch::Tensor> _swipe_once(
         }
     }
+    // fprintf(stderr, "%d before exchange cap %ld, drop count %ld %ld, lgec %ld %ld\n", rank, cap, drop_count[0], drop_count[1], gec[0], gec[1]);
     /* Send limit information back */
     _h2d(gec, d_gec, n_worker);
-    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_expert, smgr);
-    _d2h(d_lec, lec, n_expert);
+    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
+    _d2h(d_lec, lec, n_worker);
     auto d_dropcount = _h2d(drop_count, n_worker);
     ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
@@ -129,12 +134,14 @@ std::vector<torch::Tensor> _swipe_once(
     _d2h(d_dropcount, drop_count, n_worker);
     auto d_gcap = _cudamalloc<long>(n_worker);
-    _h2d(d_gcap + rank, &cap, n_worker);
+    _h2d(&cap, d_gcap + rank, 1);
     ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
             smgr->ncclcomm, smgr->stream());
     auto gcap = _d2h(d_gcap, n_worker);
     cudaDeviceSynchronize();
-    /* Re-assign counts */
+    // fprintf(stderr, "%d exchange fin, drop count %ld %ld, nlec %ld %ld, gcap %ld %ld\n", rank, drop_count[0], drop_count[1], lec[0], lec[1], gcap[0], gcap[1]);
+    /* Re-assign and update counters */
     for (long i = 0, j = 0; i < n_worker; ++i) {
         while (drop_count[i] > 0) {
             if (drop_count[i] > gcap[j]) {
@@ -148,8 +155,9 @@ std::vector<torch::Tensor> _swipe_once(
             }
         }
     }
+    // fprintf(stderr, "%d update done, lec %ld %ld, gec %ld %ld, gcap %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1], gcap[0], gcap[1]);
     for (long i = 0; i < batch_size; ++i) {
-        auto widx = gidx[i] % n_expert;
+        auto widx = gidx[i] / n_expert;
         if (lec[widx] > 0) {
             --lec[widx];
         } else {
@@ -162,9 +170,22 @@ std::vector<torch::Tensor> _swipe_once(
         }
         for (; lec[k] == 0; ++k);
         --lec[gidx[i] = k * n_expert + bias];
+        // fprintf(stderr, "%d: assign %ld to %ld\n", rank, i, k);
     }
+    *capacity_new.data_ptr<long>() = cap;
+    // fprintf(stderr, "%d all done\n", rank);
-    return {gate_idx_cpu, capacity};
+    delete [] drop_count;
+    delete [] lec;
+    delete [] gec;
+    delete [] gcap;
+    cudaFree(d_dropcount);
+    cudaFree(d_lec);
+    cudaFree(d_gec);
+    cudaFree(d_gcap);
+    return {gate_idx_cpu, capacity_new};
 }
 #undef UPDATE_COUNTERS
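The switch from `% n_expert` to `/ n_expert` above follows from the global expert index layout `gidx = worker_idx * n_expert + expert_idx`, which is also why re-assigned samples get the index `k * n_expert + bias`. A minimal single-process Python sketch of the final re-assignment loop (illustrative only; the function name and standalone form are assumptions, not the fMoE API):

```python
def swipe_reassign(gidx, lec, n_expert, bias):
    """Sketch of the last loop in _swipe_once: keep a sample on its
    target worker while that worker still has quota (lec), otherwise
    move it to the next worker k with quota, at local expert `bias`."""
    k = 0
    for i in range(len(gidx)):
        widx = gidx[i] // n_expert         # worker that sample i targets
        if lec[widx] > 0:                  # quota left: keep assignment
            lec[widx] -= 1
        else:                              # no quota: advance to a worker
            while lec[k] == 0:             # that can still take samples
                k += 1
            gidx[i] = k * n_expert + bias  # re-point sample i there
            lec[k] -= 1

# Tiny worked example: 2 workers x 2 experts; worker 0 may take 1 sample,
# worker 1 may take 3; all three samples initially want worker 0.
gidx = [0, 1, 0]   # global indices 0 and 1 both live on worker 0
lec = [1, 3]
swipe_reassign(gidx, lec, n_expert=2, bias=0)
print(gidx)        # [0, 2, 2]: two samples moved to worker 1, expert 0
```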
@@ -13,22 +13,20 @@ import fmoe_cuda as fmoe_native
 class SwipeGate(NaiveGate):
     requires_moe_group = True
 
-    def __init__(self, d_model, num_expert, world_size, topk=2):
+    def __init__(self, d_model, num_expert, world_size, top_k=2):
         super().__init__(d_model, num_expert, world_size, top_k)
 
-    def swipe_once(self, idx, capacity):
+    def swipe_once(self, idx, capacity, bias):
         with torch.no_grad():
             idx_new, capacity = fmoe_native.swipe_once(idx, capacity,
-                    self.num_expert, self.world_size)
+                    self.num_expert, self.world_size, bias)
             idx_new = idx_new.to(idx.device)
         return idx_new, capacity
 
     def forward(self, inp):
         score = self.gate(inp)
-        _, orig_idx = torch.topk(gate_score, k=self.top_k, dim=-1)
+        _, orig_idx = torch.topk(score, k=self.top_k, dim=-1)
         if not self.training:
             topk_val = F.softmax(topk_val, dim=-1)
@@ -38,10 +36,14 @@ class SwipeGate(NaiveGate):
                 dtype=torch.long)
         topk_idxs = []
+        topk_vals = []
+        idx_x = torch.arange(inp.shape[0], device=inp.device)
         for k in range(self.top_k):
-            idx, capacity = self.swipe_once(orig_idx[:, k], capacity)
+            idx, capacity = self.swipe_once(orig_idx[:, k], capacity,
+                    k % self.num_expert)
+            topk_vals.append(score[idx_x, idx])
             topk_idxs.append(idx)
         topk_idx = torch.stack(topk_idxs).transpose(0, 1)
-        topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k)
+        topk_val = torch.stack(topk_vals).transpose(0, 1)
         topk_val = F.softmax(topk_val, dim=-1)
         return topk_idx, topk_val
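For context on the `stack(...).transpose(0, 1)` pattern above: each of the `top_k` swipe rounds produces a length-`B` tensor, and stacking then transposing yields the `(B, top_k)` layout the rest of the gate contract expects. A tiny standalone check (PyTorch only; the values are made up):

```python
import torch

# Two rounds (top_k = 2) over a batch of 3 samples.
topk_idxs = [torch.tensor([0, 1, 2]),   # expert picked in round 0
             torch.tensor([3, 4, 5])]   # expert picked in round 1
topk_idx = torch.stack(topk_idxs).transpose(0, 1)
print(topk_idx.shape)  # torch.Size([3, 2])
print(topk_idx[0])     # tensor([0, 3]): sample 0's two expert indices
```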
@@ -128,9 +128,6 @@ class FMoE(nn.Module):
         self.mask_dict = mask_dict
         self.moe_group = moe_group
-        if hasattr(self.gate, 'requires_moe_group'):
-            setattr(self.gate, 'moe_gruop', self.moe_group)
 
     def expert_fn(self, inp, fwd_expert_count):
         r"""
         The default expert function which either calls the experts as a whole
@@ -5,12 +5,23 @@ from typing import Dict
 import pytest
 import torch
 import torch.distributed as dist
+from test_numerical import test_fmoe as _test_fmoe
+from test_numerical import test_fmoe_linear as _test_fmoe_linear
+from test_numerical import _test_fmoe_local_ddp
+def _ensure_initialized():
+    if not dist.is_initialized():
+        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
+        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
+        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
+        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
+        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
+        dist.init_process_group(backend="nccl")
 def _run_distributed(func, world_size, args: Dict, script=__file__):
     if torch.cuda.device_count() < world_size:
         pytest.skip("No enough GPU")
@@ -9,17 +9,7 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from fmoe.gates import GShardGate, SwitchGate
-from test_ddp import _run_distributed
-
-def _ensure_initialized():
-    if not dist.is_initialized():
-        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
-        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
-        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
-        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
-        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
-        dist.init_process_group(backend="nccl")
+from test_ddp import _ensure_initialized, _run_distributed
 
 @pytest.mark.parametrize("d_model", [1024])
+import pytest
+import os
+import sys
+import json
+import math
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+from fmoe.functions import ensure_comm
+from fmoe.gates.swipe_gate import SwipeGate
+from test_ddp import _ensure_initialized, _run_distributed
+
+
+@pytest.mark.parametrize("d_model", [1024])
+@pytest.mark.parametrize("batch_size", [16])
+@pytest.mark.parametrize("n_expert", [1, 4])
+@pytest.mark.parametrize("top_k", [2, 4])
+@pytest.mark.parametrize("world_size", [2, 4, 8])
+def test_swipe_gate(world_size, d_model, batch_size, n_expert, top_k):
+    if world_size * n_expert < 2:
+        pytest.skip("No enough experts")
+    _run_distributed('_test_swipe_gate',
+            world_size,
+            {
+                'd_model': d_model,
+                'batch_size': batch_size,
+                'n_expert': n_expert,
+                'top_k': top_k
+            },
+            script=__file__
+    )
+
+
+def _test_swipe_gate(d_model, batch_size, n_expert, top_k):
+    _ensure_initialized()
+    gate = SwipeGate(d_model, n_expert, dist.get_world_size(), top_k).cuda()
+    x = torch.rand(batch_size, d_model).cuda()
+    ensure_comm(x, None)
+    topk_idx, topk_val = gate(x)
+
+
+@pytest.mark.parametrize("batch_size", [16])
+@pytest.mark.parametrize("n_expert", [1, 4])
+@pytest.mark.parametrize("world_size", [2, 4, 8])
+def test_swipe_once(world_size, batch_size, n_expert):
+    if world_size * n_expert < 2:
+        pytest.skip("No enough experts")
+    _run_distributed('_test_swipe_once',
+            world_size,
+            {
+                'batch_size': batch_size,
+                'n_expert': n_expert
+            },
+            script=__file__
+    )
+
+
+def _test_swipe_once(batch_size, n_expert):
+    _ensure_initialized()
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    gate = SwipeGate(4, n_expert, dist.get_world_size()).cuda()
+    idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
+    capacity = torch.scalar_tensor(batch_size, dtype=torch.long)
+    ensure_comm(idx, None)
+    sys.stderr.write('{} Before swipe gate {}, capacity {}\n'.format(rank, idx, capacity))
+    new_idx, new_cap = gate.swipe_once(idx, capacity, 0)
+    sys.stderr.write('{} final gate {}, cap {}\n'.format(rank, new_idx, new_cap))
+
+
+if __name__ == '__main__':
+    if len(sys.argv) >= 3:
+        args = json.loads(sys.argv[2])
+        locals()[sys.argv[1]](**args)
+    else:
+        # test_swipe_gate(8, 4, 8, 4, 2)
+        test_swipe_once(8, 8, 4)
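The `__main__` block above is what lets `_run_distributed` re-enter this file per worker: `argv[1]` names the target function and `argv[2]` carries JSON-encoded kwargs. A hedged sketch of an equivalent manual multi-process launch (the file name `test_swipe.py` and the subprocess approach are assumptions; the real helper lives in `test_ddp.py` and is not shown in this commit):

```python
import json
import os
import subprocess

# Spawn one process per rank, setting the OMPI_* variables that
# _ensure_initialized reads to derive RANK and WORLD_SIZE.
world_size = 2
args = json.dumps({"batch_size": 16, "n_expert": 1})
procs = []
for rank in range(world_size):
    env = dict(os.environ,
               OMPI_COMM_WORLD_RANK=str(rank),
               OMPI_COMM_WORLD_SIZE=str(world_size))
    procs.append(subprocess.Popen(
        ["python", "test_swipe.py", "_test_swipe_once", args], env=env))
for p in procs:
    assert p.wait() == 0
```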