"...text-generation-inference.git" did not exist on "f848decee615ee10b78510b62036021a075dbf7b"
Commit 49b5b5d6 authored by Rick Ho

test for faster gate

parent 6e1fcca1
test_ddp.py

 import json
+import random
 import os
 import sys
 from typing import Dict
@@ -13,30 +14,34 @@ from test_numerical import _test_fmoe_local_ddp

 def _ensure_initialized():
-    if not dist.is_initialized():
+    if 'RANK' not in os.environ:
         os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
         os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
         os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
         os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
         os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
+    if not dist.is_initialized():
         dist.init_process_group(backend="nccl")

-def _run_distributed(func, world_size, args: Dict, script=__file__):
-    if torch.cuda.device_count() < world_size:
-        pytest.skip("No enough GPU")
+def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
+    device_count = torch.cuda.device_count()
+    if device_count < world_size:
+        pytest.skip("Not enough GPUs, only {} found".format(device_count))
     import subprocess
     import os
     ps = []
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "36666"
-    os.environ["OMPI_COMM_WORLD_SIZE"] = str(world_size)
+    env["MASTER_ADDR"] = "localhost"
+    env["MASTER_PORT"] = str(random.randint(50000, 60000))
+    env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
     for i in range(world_size):
-        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
+        env["OMPI_COMM_WORLD_RANK"] = str(i)
         p = subprocess.Popen(
-            [sys.executable, script, func, json.dumps(args)], stdout=subprocess.PIPE
+            [sys.executable, script, func, json.dumps(args)],
+            stdout=subprocess.PIPE,
+            env=env
         )
         ps.append(p)
...
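The switch from mutating os.environ to passing an explicit env dict is what lets each test choose its own topology variables without leaking them into later tests, and the randomized MASTER_PORT avoids collisions when runs overlap. For illustration only, here is a minimal sketch of the same launch pattern (launch_workers and base_port are hypothetical names, not part of the commit); unlike the commit's version, it copies os.environ into each child, since Popen's env argument replaces the child's entire environment:

import os
import subprocess
import sys

def launch_workers(script, world_size, base_port):
    # One Popen per rank; each child gets its own copy of the parent
    # environment plus rank-specific variables, so no mutable state is
    # shared between workers.
    procs = []
    for rank in range(world_size):
        env = dict(os.environ)              # inherit PATH, CUDA_*, etc.
        env["OMPI_COMM_WORLD_RANK"] = str(rank)
        env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
        env["MASTER_ADDR"] = "localhost"
        env["MASTER_PORT"] = str(base_port)
        procs.append(subprocess.Popen([sys.executable, script], env=env))
    for p in procs:
        assert p.wait() == 0

Note that Popen snapshots the environment at spawn time, so reusing one dict across loop iterations, as the commit does, is safe as long as every per-rank variable is overwritten before each spawn.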
test_faster_gate.py (new file)

import pytest
import os
import sys
import json
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.gates.faster_gate import FasterGate
from test_ddp import _ensure_initialized, _run_distributed


@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("gpu_per_node", [2, 4, 8])
@pytest.mark.parametrize("frac", [.2])
def test_faster_gate(n_process, d_model, batch_size, n_expert, gpu_per_node, frac):
    # Spawn n_process workers; the env vars below tell FasterGate how many
    # GPUs form one node and what fraction of tokens may leave the node.
    _run_distributed('_test_faster_gate',
                     n_process,
                     {
                         'd_model': d_model,
                         'batch_size': batch_size,
                         'n_expert': n_expert,
                         'gpu_per_node': gpu_per_node,
                         'frac': frac
                     },
                     script=__file__,
                     env=dict(
                         FMOE_TOPO_GPUS_PER_NODE=str(gpu_per_node),
                         FMOE_TOPO_OUTGOING_FRACTION=str(frac)
                     )
    )

def _test_faster_gate(d_model, batch_size, n_expert, gpu_per_node, frac):
    _ensure_initialized()
    rank = dist.get_rank()
    # Ranks are grouped into nodes of gpu_per_node GPUs each.
    node_rank = rank // gpu_per_node
    gate = FasterGate(d_model, n_expert, dist.get_world_size(), node_rank).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
    # Count tokens whose top-1 expert lives on a different node.
    cnto = 0
    idxs = topk_idx[:, 0].cpu().view(-1).numpy()
    for v in idxs:
        assert v != -1
        # Expert v belongs to rank v // n_expert, hence to node
        # (v // n_expert) // gpu_per_node.
        if v // n_expert // gpu_per_node != rank // gpu_per_node:
            cnto += 1
    # FasterGate must keep the off-node fraction at or below frac.
    assert cnto <= math.ceil(batch_size * frac)
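    # Worked example: with n_expert=1 and gpu_per_node=2, expert v sits on
    # node (v // 1) // 2, so experts 0-1 are on node 0, experts 2-3 on
    # node 1, and so on.  With batch_size=16 and frac=0.2, at most
    # ceil(16 * 0.2) = 4 tokens may be routed off-node.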

if __name__ == '__main__':
    # Entry point used by _run_distributed: argv[1] names the worker
    # function to call, argv[2] carries its kwargs as JSON.
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        test_faster_gate(8, 1024, 16, 1, 2, .2)
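The __main__ block above is the hook that makes _run_distributed work: each spawned child re-executes this file with the target function name in argv[1] and its keyword arguments as JSON in argv[2]. At module scope, locals() is the same mapping as globals(), which is why the locals() lookup resolves module-level functions. A minimal standalone sketch of the same dispatch pattern, with hypothetical names:

import json
import sys

def _demo_worker(x, y):
    print(x + y)

if __name__ == '__main__':
    # e.g.  python dispatch_demo.py _demo_worker '{"x": 1, "y": 2}'
    globals()[sys.argv[1]](**json.loads(sys.argv[2]))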