Commit 49b5b5d6 authored by Rick Ho

test for faster gate

parent 6e1fcca1
 import json
+import random
 import os
 import sys
 from typing import Dict
@@ -13,30 +14,34 @@ from test_numerical import _test_fmoe_local_ddp
 def _ensure_initialized():
     if not dist.is_initialized():
         if 'RANK' not in os.environ:
             os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
             os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
             os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
             os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
             os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
     if not dist.is_initialized():
         dist.init_process_group(backend="nccl")

-def _run_distributed(func, world_size, args: Dict, script=__file__):
-    if torch.cuda.device_count() < world_size:
-        pytest.skip("No enough GPU")
+def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
+    device_count = torch.cuda.device_count()
+    if device_count < world_size:
+        pytest.skip("No enough GPU, only {} found".format(device_count))
     import subprocess
     import os
     ps = []
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "36666"
-    os.environ["OMPI_COMM_WORLD_SIZE"] = str(world_size)
+    env["MASTER_ADDR"] = "localhost"
+    env["MASTER_PORT"] = str(random.randint(50000, 60000))
+    env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
     for i in range(world_size):
-        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
+        env["OMPI_COMM_WORLD_RANK"] = str(i)
         p = subprocess.Popen(
-            [sys.executable, script, func, json.dumps(args)], stdout=subprocess.PIPE
+            [sys.executable, script, func, json.dumps(args)],
+            stdout=subprocess.PIPE,
+            env=env
         )
         ps.append(p)
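The hunk above ends right after the workers are spawned; how the parent collects them is outside this excerpt. A minimal sketch of one way to join the `ps` list built above, purely illustrative and not taken from the commit:

# Illustrative only (not from this commit): wait for every spawned worker and
# fail the calling test if any of them exited with a non-zero status.
for p in ps:
    out, _ = p.communicate()
    assert p.returncode == 0, out.decode(errors="replace")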
......
import pytest
import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.gates.faster_gate import FasterGate
from test_ddp import _ensure_initialized, _run_distributed


@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("gpu_per_node", [2, 4, 8])
@pytest.mark.parametrize("frac", [.2])
def test_faster_gate(n_process, d_model, batch_size, n_expert, gpu_per_node, frac):
    # Spawn n_process workers; each re-enters this script and runs
    # _test_faster_gate. The gate's topology is configured via environment
    # variables passed to the children.
    _run_distributed('_test_faster_gate',
            n_process,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'gpu_per_node': gpu_per_node,
                'frac': frac
            },
            script=__file__,
            env=dict(
                FMOE_TOPO_GPUS_PER_NODE=str(gpu_per_node),
                FMOE_TOPO_OUTGOING_FRACTION=str(frac)
            )
    )


def _test_faster_gate(d_model, batch_size, n_expert, gpu_per_node, frac):
    _ensure_initialized()
    rank = dist.get_rank()
    node_rank = rank // gpu_per_node
    gate = FasterGate(d_model, n_expert, dist.get_world_size(), node_rank).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
    # Count how many tokens have their top-1 expert on a different node.
    cnto = 0
    idxs = topk_idx[:, 0].cpu().view(-1).numpy()
    for v in idxs:
        assert(v != -1)
        if v // n_expert // gpu_per_node != rank // gpu_per_node:
            cnto += 1
    # The gate may route at most ceil(batch_size * frac) tokens off-node.
    assert(cnto <= math.ceil(batch_size * frac))


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        # Invoked as a worker by _run_distributed: argv[1] names the function
        # to run, argv[2] carries its keyword arguments as JSON.
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        test_faster_gate(8, 1024, 16, 1, 2, .2)
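For reference, the command that `_run_distributed` spawns for each rank can also be reproduced by hand through the `__main__` dispatch above. A sketch of an equivalent manual launch for rank 0 of a two-process run; the file name `test_faster_gate.py` and the port value are assumptions, while the environment variable names come from the test above:

# Sketch of a manual single-rank launch (assumed file name, illustrative port).
import json
import os
import subprocess
import sys

env = dict(os.environ)  # assumption: inherit PATH/CUDA settings from the parent shell
env.update({
    "MASTER_ADDR": "localhost",
    "MASTER_PORT": "51234",             # any free port
    "OMPI_COMM_WORLD_SIZE": "2",
    "OMPI_COMM_WORLD_RANK": "0",        # rank 0 of the two workers
    "FMOE_TOPO_GPUS_PER_NODE": "2",
    "FMOE_TOPO_OUTGOING_FRACTION": "0.2",
})
args = {"d_model": 1024, "batch_size": 16, "n_expert": 1,
        "gpu_per_node": 2, "frac": 0.2}
subprocess.run(
    [sys.executable, "test_faster_gate.py", "_test_faster_gate", json.dumps(args)],
    env=env, check=True)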