"vscode:/vscode.git/clone" did not exist on "99f726ff93804287f32e46550f00c1066cc0db22"
Unverified Commit 3397bc19 authored by Rick Ho, committed by GitHub

Merge pull request #84 from laekov/swipe

SWIPE balance strategy
parents 4a9ef7fd 206f267e
#include <cstdio>
#include "balancing.cuh"
#include "global_exchange.h"
#include <torch/extension.h>
/*
......@@ -35,3 +37,150 @@ torch::Tensor _prune_gate_by_capacity(
        batch_size, n_expert, n_worker, smgr);
    return new_gate_idx;
}
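
/*
 * Helpers that shuttle small arrays between host and device.  The
 * allocating overloads (_cudamalloc and the single-source _h2d/_d2h)
 * hand ownership to the caller, who releases the buffers with cudaFree
 * or delete[], as _swipe_once does below.
 */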
template<class T>
T* _cudamalloc(size_t sz) {
    T* dptr;
    cudaMalloc(&dptr, sz * sizeof(T));
    return dptr;
}

template<class T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
    cudaMemcpy(dptr, hptr, sz * sizeof(T), cudaMemcpyHostToDevice);
    return dptr;
}

template<class T>
T* _h2d(T* hptr, size_t sz) {
    T* dptr = _cudamalloc<T>(sz);
    return _h2d(hptr, dptr, sz);
}

template<class T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
    cudaMemcpy(hptr, dptr, sz * sizeof(T), cudaMemcpyDeviceToHost);
    return hptr;
}

template<class T>
T* _d2h(const T* dptr, size_t sz) {
    T* hptr = new T[sz];
    return _d2h(dptr, hptr, sz);
}
#ifdef FMOE_USE_NCCL
#include <nccl.h>
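
/*
 * Account for moving `__count__` re-assigned samples from worker i to
 * worker j: the sender (i == rank) records more outgoing samples for j,
 * while the receiver (j == rank) records more incoming samples from i
 * and spends that much of its remaining capacity.
 */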
#define UPDATE_COUNTERS(__count__) { \
    if (i == rank) { \
        lec[j] += (__count__); \
    } \
    if (j == rank) { \
        gec[i] += (__count__); \
        cap -= (__count__); \
    } \
}
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity,
        long n_expert, long n_worker, long bias) {
    auto device_idx = gate_idx.device().index();
    auto smgr = getCudaStreamManager(device_idx);
    int rank;
    ncclCommUserRank(smgr->ncclcomm, &rank);
    cudaSetDevice(device_idx);

    auto capacity_new = capacity.clone();
    auto cap = capacity_new.item<long>();
    long batch_size = gate_idx.size(0);
    auto gate_idx_cpu = gate_idx.cpu();
    long* gidx = gate_idx_cpu.data_ptr<long>();

    /* Local count and exchange */
    long *lec = new long[n_worker];
    memset(lec, 0, n_worker * sizeof(long));
    for (long i = 0; i < batch_size; ++i) {
        ++lec[gidx[i] / n_expert];
    }
    long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
    fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
    long *gec = _d2h(d_gec, n_worker);

    /* Limit number of incoming samples */
    long *drop_count = new long[n_worker];
    memset(drop_count, 0, n_worker * sizeof(long));
    for (long i = 0; i < n_worker; ++i) {
        if (cap >= gec[i]) {
            drop_count[i] = 0;
            cap -= gec[i];
        } else {
            drop_count[i] = gec[i] - cap;
            gec[i] = cap;
            cap = 0;
        }
    }

    /* Send limit information back */
    _h2d(gec, d_gec, n_worker);
    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
    _d2h(d_lec, lec, n_worker);
    auto d_dropcount = _h2d(drop_count, n_worker);
    ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
            smgr->ncclcomm, smgr->stream());
    _d2h(d_dropcount, drop_count, n_worker);
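
    /* After the all-reduce, drop_count[i] is the global number of dropped
       samples that were emitted by worker i. */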
    auto d_gcap = _cudamalloc<long>(n_worker);
    _h2d(&cap, d_gcap + rank, 1);
    ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
            smgr->ncclcomm, smgr->stream());
    auto gcap = _d2h(d_gcap, n_worker);

    /* Re-assign and update counters */
    for (long i = 0, j = 0; i < n_worker; ++i) {
        while (drop_count[i] > 0) {
            if (drop_count[i] > gcap[j]) {
                drop_count[i] -= gcap[j];
                UPDATE_COUNTERS(gcap[j]);
                ++j;
            } else {
                gcap[j] -= drop_count[i];
                UPDATE_COUNTERS(drop_count[i]);
                break;
            }
        }
    }
    for (long i = 0; i < batch_size; ++i) {
        auto widx = gidx[i] / n_expert;
        if (lec[widx] > 0) {
            --lec[widx];
        } else {
            gidx[i] = -1;
        }
    }
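
    /* Samples marked -1 above exceeded their destination's quota; the loop
       below re-routes them to workers with leftover quota in lec, assigning
       them to expert `bias` on the chosen worker. */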
    for (long i = 0, k = 0; i < batch_size; ++i) {
        if (gidx[i] != -1) {
            continue;
        }
        for (; lec[k] == 0; ++k);
        --lec[k];
        gidx[i] = k * n_expert + bias;
    }
    *capacity_new.data_ptr<long>() = cap;

    delete [] drop_count;
    delete [] lec;
    delete [] gec;
    delete [] gcap;
    cudaFree(d_dropcount);
    cudaFree(d_lec);
    cudaFree(d_gec);
    cudaFree(d_gcap);
    return {gate_idx_cpu, capacity_new};
}
#undef UPDATE_COUNTERS
#endif
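
For intuition, here is a single-process sketch of the bookkeeping that _swipe_once performs across workers. It is a reading aid only, not part of the patch: the real kernel runs once per rank and exchanges its count vectors via NCCL, and every name below is the sketch's own.

import numpy as np

def swipe_reference(gate_worker, capacity, n_worker):
    # gate_worker[i] lists the destination worker of each sample on worker i.
    lec = np.zeros((n_worker, n_worker), dtype=np.int64)  # lec[i][j]: i -> j
    for i in range(n_worker):
        for dst in gate_worker[i]:
            lec[i][dst] += 1
    gec = lec.T.copy()                   # gec[j][i]: what j receives from i
    cap = np.array(capacity, dtype=np.int64)
    drop = np.zeros(n_worker, dtype=np.int64)
    for j in range(n_worker):            # receivers clip intake to capacity
        for i in range(n_worker):
            take = min(gec[j][i], cap[j])
            drop[i] += gec[j][i] - take  # the all-reduced drop_count
            gec[j][i] = take
            cap[j] -= take
    lec = gec.T.copy()                   # accepted counts travel back
    j = 0                                # greedy re-assignment of drops
    for i in range(n_worker):
        while drop[i] > 0:               # assumes leftover capacity suffices
            take = min(drop[i], cap[j])
            lec[i][j] += take
            gec[j][i] += take
            cap[j] -= take
            drop[i] -= take
            if drop[i] > 0:
                j += 1
    return lec, gec, cap

# e.g. two workers, capacity 3 each, worker 0 routing everything to worker 1:
# swipe_reference([[1, 1, 1, 1], [1, 1]], [3, 3], 2)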
......@@ -52,6 +52,9 @@ torch::Tensor _limit_by_capacity(
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker);

std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity_tensor,
        long n_expert, long n_worker, long bias);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
......@@ -59,6 +62,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("global_scatter", &_global_scatter, "FastMoE global scatter (CUDA)");
m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
m.def("swipe_once", &_swipe_once, "SWIPE balance strategy(CUDA)");
#endif
m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
......
......@@ -5,6 +5,33 @@
#ifdef FMOE_USE_NCCL
#include <nccl.h>
void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr) {
    NCCL_SAFE_CALL(ncclGroupStart());
    for (int i = 0; i < world_size; ++i) {
        NCCL_SAFE_CALL(ncclSend(
                local_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
        NCCL_SAFE_CALL(ncclRecv(
                global_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
    smgr->sync(1);
}
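
The send/recv loop above is an all-to-all over the count vectors: chunk i of the local counts goes to rank i. For intuition only (the patch itself stays on raw NCCL), the same exchange expressed with torch.distributed, assuming an initialized NCCL process group:

import torch
import torch.distributed as dist

n_expert = 1                          # one count per worker, as _swipe_once uses
world_size = dist.get_world_size()
lec = torch.zeros(world_size * n_expert, dtype=torch.long, device='cuda')
gec = torch.empty_like(lec)
dist.all_to_all_single(gec, lec)      # chunk i of lec is sent to rank i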
torch::Tensor _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers) {
......@@ -31,7 +58,7 @@ torch::Tensor _global_scatter(
auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
auto smgr = getCudaStreamManager(input_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"fmoe_cuda_global_scatter", ([&] {
fmoe_cuda_global_scatter_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
......@@ -57,7 +84,7 @@ torch::Tensor _global_gather(
auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
auto smgr = getCudaStreamManager(output_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
"fmoe_cuda_global_gather", ([&] {
fmoe_cuda_global_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
......
......@@ -2,30 +2,11 @@
#ifdef FMOE_USE_NCCL
void fmoe_cuda_expert_exchange_impl(
const long* local_expert_count,
long* global_expert_count,
int n_expert, int world_size,
CudaStreamManager* smgr) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int i = 0; i < world_size; ++i) {
NCCL_SAFE_CALL(ncclSend(
local_expert_count + n_expert * i,
n_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
NCCL_SAFE_CALL(ncclRecv(
global_expert_count + n_expert * i,
n_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
}
NCCL_SAFE_CALL(ncclGroupEnd());
smgr->sync(1);
}
CudaStreamManager* smgr);
template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
......@@ -50,9 +31,9 @@ void fmoe_cuda_global_scatter_impl(
int idx = i + j * n_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
local_input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
......@@ -106,9 +87,9 @@ void fmoe_cuda_global_gather_impl(
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
......
......@@ -7,3 +7,5 @@ from .noisy_gate import NoisyGate
from .gshard_gate import GShardGate
from .switch_gate import SwitchGate
from .swipe_gate import SwipeGate
......@@ -23,3 +23,7 @@ class BaseGate(nn.Module):
        if clear:
            self.loss = None
        return loss

    @property
    def has_loss(self):
        return self.loss is not None
r"""
Balanced gate using SWIPE algorithm
"""
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .naive_gate import NaiveGate
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native
class SwipeGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__(d_model, num_expert, world_size, top_k)
def swipe_once(self, idx, capacity, bias):
with torch.no_grad():
idx_new, capacity = fmoe_native.swipe_once(idx, capacity,
self.num_expert, self.world_size, bias)
idx_new = idx_new.to(idx.device)
return idx_new, capacity
    def forward(self, inp):
        score = self.gate(inp)
        orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1)
        if not self.training:
            topk_val = F.softmax(orig_score, dim=-1)
            return orig_idx, topk_val
        capacity = torch.scalar_tensor(inp.shape[0] * self.top_k,
                dtype=torch.long)
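        # Re-balance each of the top-k choices in turn, threading the
        # remaining global capacity through the successive swipe_once calls.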
        topk_idxs = []
        topk_vals = []
        idx_x = torch.arange(inp.shape[0], device=inp.device)
        for k in range(self.top_k):
            idx, capacity = self.swipe_once(orig_idx[:, k], capacity,
                    k % self.num_expert)
            topk_vals.append(score[idx_x, idx])
            topk_idxs.append(idx)
        topk_idx = torch.stack(topk_idxs).transpose(0, 1)
        topk_val = torch.stack(topk_vals).transpose(0, 1)
        topk_val = F.softmax(topk_val, dim=-1)
        return topk_idx, topk_val
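
For context, a minimal usage sketch (not part of the patch). It assumes fmoe's FMoE layer accepts a gate class via its `gate` argument and that a NCCL process group is already initialized, as in the tests below:

import torch
import torch.distributed as dist
from fmoe import FMoE                    # assumed import path
from fmoe.gates import SwipeGate

dist.init_process_group(backend="nccl")  # swipe_once needs a NCCL comm
moe = FMoE(num_expert=4, d_model=1024,
           world_size=dist.get_world_size(),
           top_k=2, gate=SwipeGate).cuda()
x = torch.rand(16, 1024).cuda()
y = moe(x)   # during training, SwipeGate re-balances the top-k assignment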
......@@ -51,9 +51,12 @@ def add_balance_log(model, writer, iteration):
    while hasattr(model, 'module'):
        model = model.module
    balance_dict_tensor = torch.vstack(
        [l.mlp.gate.get_loss(clear=True) for l in model.language_model.transformer.layers]
    ).detach()
    losses = [l.mlp.gate.get_loss(clear=True)
              for l in model.language_model.transformer.layers
              if l.mlp.gate.has_loss]
    if len(losses) == 0:
        return
    balance_dict_tensor = torch.vstack(losses).detach()
    world_group = get_torch_default_comm()
    world_size = torch.distributed.get_world_size(group=world_group)
    torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
......
......@@ -95,6 +95,9 @@ class MegatronMLP(FMoETransformerMLP):
elif args.balance_strategy == "switch":
from fmoe.gates import SwitchGate
gate = SwitchGate
elif args.balance_strategy == "swipe":
from fmoe.gates import SwipeGate
gate = SwipeGate
elif gate is None:
assert False, "Undefined balance strategy {}" % (args.balance_strategy)
......
......@@ -20,15 +20,19 @@ def patch_forward_step(forward_step_func):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)
        if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
        if not is_pipeline_last_stage() or not args.balance_strategy:
            return output
        loss_name = args.balance_strategy + "_loss"
        while hasattr(model, 'module'):
            model = model.module
        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
                     for l in model.language_model.transformer.layers]
                     for l in model.language_model.transformer.layers
                     if l.mlp.gate.has_loss]
        if len(loss_list) == 0:
            return output
        loss_name = args.balance_strategy + "_loss"
        (loss, state_dict), bal_loss = (
            output,
            torch.cat(loss_list).mean() * args.balance_loss_weight
......
......@@ -5,12 +5,23 @@ from typing import Dict
import pytest
import torch
import torch.distributed as dist
from test_numerical import test_fmoe as _test_fmoe
from test_numerical import test_fmoe_linear as _test_fmoe_linear
from test_numerical import _test_fmoe_local_ddp
def _ensure_initialized():
    if not dist.is_initialized():
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
        dist.init_process_group(backend="nccl")


def _run_distributed(func, world_size, args: Dict, script=__file__):
    if torch.cuda.device_count() < world_size:
        pytest.skip("Not enough GPUs")
......
......@@ -9,17 +9,7 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.gates import GShardGate, SwitchGate
from test_ddp import _run_distributed
def _ensure_initialized():
    if not dist.is_initialized():
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
        dist.init_process_group(backend="nccl")
from test_ddp import _ensure_initialized, _run_distributed
@pytest.mark.parametrize("d_model", [1024])
......
import pytest
import os
import sys
import json
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.functions import ensure_comm
from fmoe.gates.swipe_gate import SwipeGate
from test_ddp import _ensure_initialized, _run_distributed
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("top_k", [2, 4])
@pytest.mark.parametrize("world_size", [2, 4, 8])
def test_swipe_gate(world_size, d_model, batch_size, n_expert, top_k):
    if world_size * n_expert < 2:
        pytest.skip("Not enough experts")
    _run_distributed('_test_swipe_gate',
            world_size,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'top_k': top_k
            },
            script=__file__
    )
def _test_swipe_gate(d_model, batch_size, n_expert, top_k):
    _ensure_initialized()
    gate = SwipeGate(d_model, n_expert, dist.get_world_size(), top_k).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    ensure_comm(x, None)
    topk_idx, topk_val = gate(x)
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("world_size", [2, 4, 8])
def test_swipe_once(world_size, batch_size, n_expert):
if world_size * n_expert < 2:
pytest.skip("No enough experts")
_run_distributed('_test_swipe_once',
world_size,
{
'batch_size': batch_size,
'n_expert': n_expert
},
script=__file__
)
def _test_swipe_once(batch_size, n_expert):
    _ensure_initialized()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    gate = SwipeGate(4, n_expert, world_size).cuda()
    idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
    capacity = torch.scalar_tensor(batch_size * 2, dtype=torch.long)
    ensure_comm(idx, None)
    new_idx, new_cap = gate.swipe_once(idx, capacity, 0)
    idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
    new_idx, new_cap = gate.swipe_once(idx, new_cap, 0)
if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        test_swipe_gate(8, 4, 8, 4, 2)
        # test_swipe_once(8, 800, 4)