"src/vscode:/vscode.git/clone" did not exist on "eadf0e2555cfa19b033e02de53553f71ac33536f"
Commit 5e9bb2e9 authored by Rick Ho

do not require comm in non-nccl environment

parent 8f1f2ca5
#!/usr/bin/env python
# encoding: utf-8
# File Name: topk.py
# Author: Jiezhong Qiu
# Create Time: 2020/11/24 20:23
import time

import torch

from mem_transformer import my_topk

# Micro-benchmark: compare torch.Tensor.topk against the custom my_topk
# kernel for k=1 and k=2 on a random (16, 512, 512) CUDA tensor. Each
# timed region is bracketed by torch.cuda.synchronize() so the wall-clock
# time covers the kernel itself; note that the first timed call also
# absorbs one-off CUDA warm-up cost.
output = torch.rand(16, 512, 512).cuda()

# Built-in top-1 along the last dimension.
torch.cuda.synchronize()
start = time.time()
_, pred = output.topk(k=1, dim=-1, largest=True, sorted=False)
torch.cuda.synchronize()
print("torch.top1 Time :{}".format(time.time() - start))

# Custom top-1.
torch.cuda.synchronize()
start = time.time()
_, pred_ = my_topk(output, k=1, inplace=True)
torch.cuda.synchronize()
print("my top1 Time :{}".format(time.time() - start))

# Built-in top-2.
torch.cuda.synchronize()
start = time.time()
_, pred = output.topk(k=2, dim=-1, largest=True, sorted=False)
torch.cuda.synchronize()
print("torch.top2 Time :{}".format(time.time() - start))

# Custom top-2.
torch.cuda.synchronize()
start = time.time()
_, pred_ = my_topk(output, k=2, inplace=True)
torch.cuda.synchronize()
print("my top2 Time :{}".format(time.time() - start))
@@ -21,9 +21,9 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
     world_size: number of workers that hold different experts.
     comm: the communicator of all workers in the expert-parallel group.
     """
-    if comm is None:
-        comm = get_torch_default_comm()
+    if world_size > 1:
+        if comm is None:
+            comm = get_torch_default_comm()
+        fmoe_cuda.ensure_nccl(comm, gate)
     with torch.no_grad():
......
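The effect of the hunk is to defer all communicator setup to the multi-worker case. A minimal sketch of the patched control flow, reconstructed from the diff context only (everything below the guard is elided in the hunk):

def moe_prepare_forward(gate, num_expert, world_size, comm=None):
    # Sketch reconstructed from the hunk above, not the full function.
    # With world_size == 1 (a single-process, non-NCCL run), no default
    # communicator is fetched and no NCCL setup runs, so the function no
    # longer requires a communicator in that case.
    if world_size > 1:
        if comm is None:
            comm = get_torch_default_comm()
        fmoe_cuda.ensure_nccl(comm, gate)
    with torch.no_grad():
        ...  # remainder of the function, elided in the hunk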