Commit 91a5e794 authored by Rick Ho

faster policies

parent 794dd0e6
import os

def float_from_env(key, default=-1):
    if key in os.environ:
        return float(os.environ[key])
    return default

def switch_from_env(key, default=False):
    if key in os.environ:
        return os.environ[key] in ['1', 'ON']
    return default
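For context, a minimal usage sketch of these two helpers (editor's illustration, not part of the commit). The environment variable names are the ones used elsewhere in this change; the values set here are made up.

import os
from fmoe.fastermoe.config import float_from_env, switch_from_env

os.environ['FMOE_FASTER_GLBPLC_ALPHA'] = '4'       # made-up value
os.environ['FMOE_FASTER_SHADOW_ENABLE'] = 'ON'     # '1' or 'ON' both count as enabled

alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2)              # -> 4.0
shadow = switch_from_env('FMOE_FASTER_SHADOW_ENABLE')              # -> True
bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8)  # unset -> default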
@@ -9,6 +9,8 @@ from fmoe.functions import _local_scatter, _local_gather
 import fmoe_cuda as fmoe_native
 from fmoe.fastermoe import expert_utils
+from .shadow_policy import get_shadow_policy

 class MoEForward(Function):
     @staticmethod
@@ -31,6 +33,7 @@ class MoEForward(Function):
             x = x.data
             with torch.enable_grad():
                 x.requires_grad = True
+                # To skip torch autograd's version check.
                 with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
                     y0 = expert_fn(x, [x.shape[0]])
             ctx.gibs[idx] = x
@@ -101,6 +104,9 @@ class MoEForward(Function):
         return (None, None, grad_in, None, None, None, None, None, None, None, None)

+policy_fn = None
+
 def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
     # TODO: Using multiple tensors as input is to be supported.
     assert(isinstance(inp, torch.Tensor))
@@ -114,9 +120,13 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp
         fwd_batch_size,
     ) = prepare_forward(gate, n_expert, world_size)
-    # TODO: Expert shadowing is to be supported. Currently using all 0s
+    global policy_fn
+    if policy_fn is None:
+        policy_fn = get_shadow_policy(d_model=inp.shape[-1])
     if stored_models is None:
-        stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)
+        stored_models = policy_fn(local_expert_count, global_expert_count,
+                n_expert, world_size)
     topk = 1
     if len(gate.shape) == 2:
...
import os

import torch
import torch.distributed as dist

from .config import float_from_env, switch_from_env
from fmoe.functions import get_moe_group
def global_policy(local_expert_count, _gec, num_expert, world_size):
    r"""
    This is the policy for two-layer MLPs, using the formula in the PPoPP paper.
    A few parameters are used in this policy.
    * `d_model`: feature length of the MLP input and output.
    * `alpha`: the ratio of the MLP's hidden size to `d_model`.
    * `bw_net`: bandwidth of the network, in bytes per second.
    * `bw_mm`: computation throughput of performing GeMM, in FLOPs per second.
    """
    bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8)
    bw_mm = float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12)
    alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2)
    d_model = float_from_env('FMOE_FASTER_GLBPLC_DMODEL', 2048)

    moe_group = get_moe_group()
    local_expert_count = local_expert_count.cuda()
    agecs = [torch.empty_like(local_expert_count) for _ in range(world_size)]
    dist.all_gather(agecs, local_expert_count, group=moe_group)
    all_global_expert_count = torch.stack(agecs)

    # TODO: data type other than float
    data_size = 4

    fwd_expert_counts = all_global_expert_count.sum(1).cpu()
    B_ws, indices = fwd_expert_counts.flatten().sort(0, descending=True)

    alphaH2 = alpha * (d_model ** 2)
    B_w = B_ws[0]

    comm = float('+inf')
    send_feature_time = d_model * data_size / bw_net
    send_model_time = 2 * alphaH2 * data_size / bw_net
    comp_time = 4 * alphaH2 / bw_mm

    lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w
    res = torch.zeros(world_size * num_expert, dtype=torch.bool)
    shadow_time = 0

    for i, index in enumerate(indices):
        if i + 1 == indices.numel():
            break
        B_k = B_ws[i + 1]
        shadow_time += send_model_time
        lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + shadow_time
        if lat_new < lat_base:
            lat_base = lat_new
            res[index] = True
        else:
            break
    return res
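To make the formula concrete, here is a rough worked example with the default constants (editor's illustration, not part of shadow_policy.py; the expert loads B_w and B_k are assumed). The policy keeps shadowing the most heavily loaded experts as long as the reduced bottleneck latency plus the cost of broadcasting their weights beats the baseline.

# Editor's illustration only; all concrete values below are assumptions.
d_model, alpha, data_size = 2048, 2, 4
bw_net, bw_mm = 50 * 1e9 / 8, 11.5e12       # default bytes/s and FLOPs/s

alphaH2 = alpha * d_model ** 2              # weight-size factor of one expert (~8.4e6)
send_feature_time = d_model * data_size / bw_net    # ~1.3e-6 s to move one token
send_model_time = 2 * alphaH2 * data_size / bw_net  # ~1.1e-2 s to broadcast one expert
comp_time = 4 * alphaH2 / bw_mm                     # ~2.9e-6 s of GeMM per token

B_w, B_k = 8192, 1024                       # assumed largest and next-largest expert loads
lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w                    # ~0.11 s
lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + send_model_time   # ~0.025 s
# lat_new < lat_base, so the busiest expert would be marked for shadowing.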
def no_shadow_policy(_lec, _gec, num_expert, world_size):
    res = torch.zeros(world_size * num_expert, dtype=torch.bool)
    return res

def get_shadow_policy(d_model=None):
    if d_model is not None and 'FMOE_FASTER_GLBPLC_DMODEL' not in os.environ:
        os.environ['FMOE_FASTER_GLBPLC_DMODEL'] = str(d_model)
    if not switch_from_env('FMOE_FASTER_SHADOW_ENABLE'):
        return no_shadow_policy
    return global_policy
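The policy selection above is driven entirely by environment variables. A hedged sketch of how a user might enable and tune it (editor's illustration; the specific values are assumptions, not recommendations):

import os

# Shadowing is off by default; get_shadow_policy() then returns no_shadow_policy.
os.environ['FMOE_FASTER_SHADOW_ENABLE'] = 'ON'

# Optional knobs read by global_policy (values here are made up).
os.environ['FMOE_FASTER_GLBPLC_NETBW'] = str(100 * 1e9 / 8)  # 100 Gbps link, in bytes/s
os.environ['FMOE_FASTER_GLBPLC_GPUTP'] = '19.5e12'           # GPU GeMM throughput, FLOPs/s
os.environ['FMOE_FASTER_GLBPLC_ALPHA'] = '4'                 # hidden size = 4 * d_model

from fmoe.fastermoe.shadow_policy import get_shadow_policy
# d_model is recorded into FMOE_FASTER_GLBPLC_DMODEL unless that variable is already set.
policy_fn = get_shadow_policy(d_model=1024)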
@@ -10,12 +10,21 @@ import fmoe_cuda
 from .utils import get_torch_default_comm

+_moe_group = None
+
 def ensure_comm(t, comm):
     if comm is None:
         comm = get_torch_default_comm()
+    global _moe_group
+    _moe_group = comm
     fmoe_cuda.ensure_nccl(comm, t)

+def get_moe_group():
+    return _moe_group
+
 def count_by_gate(gate, num_expert, world_size, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
...
@@ -11,6 +11,8 @@ from .functions import MOEScatter, MOEGather
 from .functions import AllGather, Slice
 from .gates import NaiveGate
+from .fastermoe.config import switch_from_env

 def mark_module_parallel_comm(module, comm):
     r"""
@@ -76,7 +78,9 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size, *
     return outp

-if os.environ.get('FMOE_FASTER_SCHEDULE_ENABLE', '0') in ['1', 'ON']:
+fmoe_faster_schedule = False
+if switch_from_env('FMOE_FASTER_SCHEDULE_ENABLE', False):
+    fmoe_faster_schedule = True
     from .fastermoe.schedule import _fmoe_general_global_forward
...
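As a usage note (editor's sketch, not part of the diff): with this change the faster schedule, and through it the shadow policy, is opted into via the environment before fmoe is imported, for example:

import os

os.environ['FMOE_FASTER_SCHEDULE_ENABLE'] = 'ON'  # swap in fastermoe's scheduled forward
os.environ['FMOE_FASTER_SHADOW_ENABLE'] = 'ON'    # let the shadow policy pick experts

import fmoe  # fmoe.layers reads FMOE_FASTER_SCHEDULE_ENABLE at import time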
@@ -3,6 +3,7 @@ import random
 import os
 import sys
 from typing import Dict
+import random

 import pytest
 import torch
@@ -19,7 +20,6 @@ def _ensure_initialized():
     os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
     os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
-    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
     if not dist.is_initialized():
         dist.init_process_group(backend="nccl")
...
@@ -17,25 +17,28 @@ from fmoe.layers import _fmoe_general_global_forward as naive_fwd
 @pytest.mark.parametrize("n_process", [8])
 @pytest.mark.parametrize("d_model", [1024])
-@pytest.mark.parametrize("batch_size", [16])
+@pytest.mark.parametrize("batch_size", [16, 512])
 @pytest.mark.parametrize("n_expert", [1])
 @pytest.mark.parametrize("group_sz", [1, 2, 4])
-def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz):
+@pytest.mark.parametrize("pass_stored", [False, True])
+def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz, pass_stored):
     _run_distributed('_test_faster_shadow',
             n_process,
             {
                 'd_model': d_model,
                 'batch_size': batch_size,
-                'n_expert': n_expert
+                'n_expert': n_expert,
+                'pass_stored': pass_stored
             },
             script=__file__,
             env=dict(
-                FMOE_FASTER_GROUP_SIZE=str(group_sz)
+                FMOE_FASTER_GROUP_SIZE=str(group_sz),
+                FMOE_FASTER_SHADOW_ENABLE='ON'
             )
     )

-def _test_faster_shadow(d_model, batch_size, n_expert):
+def _test_faster_shadow(d_model, batch_size, n_expert, pass_stored):
     _ensure_initialized()
     rank = dist.get_rank()
     world_size = dist.get_world_size()
@@ -58,15 +61,20 @@ def _test_faster_shadow(d_model, batch_size, n_expert):
         y = m2(x)
         return y

-    stored_models = torch.randint(0, 2, (world_size,)).bool().cuda()
-    dist.broadcast(stored_models, 0)
-    stored_models = stored_models.cpu()
+    if pass_stored:
+        stored_models = torch.randint(0, 2, (world_size,)).bool().cuda()
+        dist.broadcast(stored_models, 0)
+        stored_models = stored_models.cpu()
     # if rank == 0:
     # print('stored models {}'.format(stored_models))

     ensure_comm(x1, None)
-    y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1, stored_models=stored_models)
+    if pass_stored:
+        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1,
+                stored_models=stored_models)
+    else:
+        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1)
     y1.sum().backward()

     y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2)
@@ -82,4 +90,4 @@ if __name__ == '__main__':
         locals()[sys.argv[1]](**args)
     else:
         # test_faster_shadow(8, 16, 16, 1, 2)
-        _test_faster_shadow(4, 2, 1)
+        _test_faster_shadow(1024, 16, 1, True)