Commit ed69591a authored by Rick Ho

update fmoe files

parent 0c47afbb
## Introduction
This directory contains our PyTorch implementation of Transformer-XL. Note that our state-of-the-art results reported in the paper were obtained by training the model on a large-scale TPU cluster, and our PyTorch codebase currently does not support distributed training. Here we provide two sets of hyperparameters and scripts:
- `*large.sh` are for the SoTA setting with large models which might not be directly runnable on a local GPU machine.
- `*base.sh` are for the base models which can be run on a few GPUs.
The PyTorch implementation produces similar results to the TF codebase under the same settings in our preliminary experiments.
## Prerequisite
- PyTorch 0.4: `conda install pytorch torchvision -c pytorch`
## Data Preparation
`bash getdata.sh`
## Training and Evaluation
#### Replicate the "bpc = 1.06" result on `enwik8` with a 12-layer Transformer-XL
- Make sure the machine has **4 GPUs**, each with **at least 11GB of memory**
- Training
`bash run_enwik8_base.sh train --work_dir PATH_TO_WORK_DIR`
- Evaluation
`bash run_enwik8_base.sh eval --work_dir PATH_TO_WORK_DIR`
#### Replicate the "PPL = 24.03" result on `wikitext-103` with Transformer-XL
- Make sure the machine has **4 GPUs**, each with **at least 11GB of memory**
- Training
`bash run_wt103_base.sh train --work_dir PATH_TO_WORK_DIR`
- Evaluation
`bash run_wt103_base.sh eval --work_dir PATH_TO_WORK_DIR`
#### Other options:
- `--batch_chunk`: this option allows one to trade speed for memory. For `batch_chunk > 1`, the program will split each training batch into `batch_chunk` sub-batches, run the forward and backward passes on each sub-batch sequentially, and accumulate the gradients, dividing them by `batch_chunk`. Hence, memory usage decreases roughly in proportion to `batch_chunk`, while computation time increases correspondingly (see the sketch after this list).
- `--div_val`: when using adaptive softmax and embedding, the embedding dimension is divided by `div_val` from bin $i$ to bin $i+1$. This saves both GPU memory and the parameter budget.
- `--fp16` and `--dynamic-loss-scale`: Run in pseudo-fp16 mode (fp16 storage, fp32 math) with dynamic loss scaling.
- Note: to explore the `--fp16` option, please make sure the `apex` package is installed (https://github.com/NVIDIA/apex/).
- To see performance without the recurrence mechanism, simply use `mem_len=0` in all your scripts.
- To see performance of a standard Transformer without relative positional encodings or recurrence mechanisms, use `attn_type=2` and `mem_len=0`.
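
For reference, here is a minimal sketch of the gradient-accumulation pattern behind `--batch_chunk` (generic PyTorch; `model`, `criterion`, `optimizer`, and `loader` are illustrative placeholders, not this repo's API):

```python
for data, target in loader:
    optimizer.zero_grad()
    # Run forward/backward on each sub-batch; .backward() accumulates gradients.
    for d, t in zip(torch.chunk(data, batch_chunk, dim=0),
                    torch.chunk(target, batch_chunk, dim=0)):
        # Divide so the accumulated gradient equals that of the full batch.
        loss = criterion(model(d), t) / batch_chunk
        loss.backward()
    optimizer.step()
```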
#### Other datasets:
- `Text8` character-level language modeling: check out `run_text8_base.sh`
- `lm1b` word-level language modeling: check out `run_lm1b_base.sh`
from .moe import FMoE, BruteForceMoE
import math

import torch
from torch import nn

from .moe_function import moe


class FMoE(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
                 world_size=None):
        super(FMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.world_size = world_size
        # One weight matrix per local expert.
        self.weight = nn.Parameter(
            torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize each expert the same way nn.Linear initializes its weight.
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat,
                               out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        return moe(inp, gate.int(), self.weight, self.world_size)
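

# Usage (illustrative; assumes the moe_cuda extension is built and a GPU is
# available):
#
#     fmoe = FMoE(num_expert=4, in_feat=16, out_feat=32).cuda()
#     x = torch.randn(8, 16).cuda()
#     gate = torch.randint(0, 4, (8,)).cuda()   # one expert id per row
#     y = fmoe(x, gate)                         # y[i] = x[i] @ W[gate[i]].t()
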
class BruteForceMoE(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
                 world_size=1):
        super(BruteForceMoE, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.world_size = world_size
        # Holds the weights of *all* experts across all workers.
        self.weight = nn.Parameter(
            torch.Tensor(num_expert * world_size, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize every global expert, not just the local ones.
        for i in range(self.num_expert * self.world_size):
            linear = nn.Linear(in_features=self.in_feat,
                               out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, gate):
        gate_long = gate.long()
        batch_size = inp.size(0)
        x = inp.new_zeros((batch_size, self.out_feat))
        # Route each row through its selected expert, one row at a time.
        for i in range(batch_size):
            x[i] = inp[i] @ self.weight[gate_long[i]].t()
        return x
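

# Note: `gate` indexes the global expert list, so with num_expert=2 and
# world_size=2 the weight tensor has shape (4, out_feat, in_feat) and gate
# values lie in [0, 4). BruteForceMoE is a slow reference implementation
# used to check FMoE for correctness (see moe_test.py).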
import torch
from torch.autograd import Function

import moe_cuda


class MOELocal(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight):
        # Count rows per expert and compute each row's position after sorting
        # by expert id.
        expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
        # Reorder rows so that rows belonging to the same expert are contiguous.
        input_buf, = moe_cuda.local_scatter(inp, pos)
        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
        # Restore the original row order.
        output = moe_cuda.local_gather(output_buf, pos)

        variables = [input_buf, gate, weight, expert_count, pos]
        ctx.save_for_backward(*variables)
        return output[0]

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors

        # Backward retraces the forward data layout on the gradients.
        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = moe_cuda.backward(
            grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)

        return grad_inp, None, grad_weight
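

# Assumed semantics of the scatter/gather pair above, written out in plain
# PyTorch for small-input sanity checks (our reading of the CUDA ops, not
# part of the extension):
#
#     def moe_local_reference(inp, gate, weight):
#         pos = torch.argsort(gate)          # local_scatter: group rows by expert
#         buf = inp[pos]
#         out = torch.empty(inp.size(0), weight.size(1), device=inp.device)
#         start = 0
#         for e in range(weight.size(0)):    # moe_cuda.forward: per-expert GEMMs
#             n = int((gate == e).sum())
#             out[start:start + n] = buf[start:start + n] @ weight[e].t()
#             start += n
#         res = torch.empty_like(out)
#         res[pos] = out                     # local_gather: restore row order
#         return res
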
class MOEGlobal(Function):
    @staticmethod
    def forward(ctx, inp, gate, weight, world_size):
        num_expert = weight.shape[0]
        # Per-(worker, expert) row counts on this worker, plus the sorted
        # position of each row.
        local_expert_count, pos = moe_cuda.expert_count(
            gate, world_size * num_expert)
        # Exchange counts so every worker knows how many rows it will receive
        # for each of its local experts.
        global_expert_count, fwd_expert_count = moe_cuda.expert_exchange(
            local_expert_count, num_expert, world_size)
        fwd_batch_size = int(fwd_expert_count.sum().item())

        # Group rows by destination, ship them to the owning workers, run the
        # local experts, then route the results back to their source rows.
        local_input_buf, = moe_cuda.local_scatter(inp, pos)
        global_input_buf, = moe_cuda.global_scatter(
            local_input_buf, local_expert_count, global_expert_count,
            fwd_batch_size, world_size)
        global_output_buf, = moe_cuda.forward(
            global_input_buf, weight, fwd_expert_count)
        local_output_buf, = moe_cuda.global_gather(
            global_output_buf, local_expert_count, global_expert_count,
            inp.shape[0], world_size)
        output, = moe_cuda.local_gather(local_output_buf, pos)

        variables = (global_input_buf, gate, weight,
                     local_expert_count, global_expert_count,
                     fwd_expert_count, pos)
        ctx.moe_args = (num_expert, inp.shape[0], fwd_batch_size, world_size)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, gate, weight,
         local_expert_count, global_expert_count, fwd_expert_count,
         pos) = ctx.saved_tensors
        num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args

        # Backward retraces the forward communication pattern on the gradients.
        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        global_grad_out_buf, = moe_cuda.global_scatter(
            grad_out_buf, local_expert_count, global_expert_count,
            fwd_batch_size, world_size)
        grad_inp_buf, grad_weight = moe_cuda.backward(
            global_grad_out_buf, input_buf, weight, fwd_expert_count)
        local_grad_inp_buf, = moe_cuda.global_gather(
            grad_inp_buf, local_expert_count, global_expert_count,
            local_batch_size, world_size)
        grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)

        return grad_inp, None, grad_weight, None


def moe(inp, gate, weight, world_size):
    if world_size is not None and world_size > 1:
        return MOEGlobal.apply(inp, gate, weight, world_size)
    else:
        return MOELocal.apply(inp, gate, weight)
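

# Usage (illustrative): a world_size of None or 1 selects the single-process
# path; a larger value routes rows across ranks before the expert GEMMs.
# `dist` below stands for torch.distributed (assumed initialized).
#
#     y = moe(x, gate.int(), weight, None)                   # MOELocal
#     y = moe(x, gate.int(), weight, dist.get_world_size())  # MOEGlobal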
import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Path to the CUDA sample helper headers (helper_cuda.h and friends).
CUDA_HELPER = os.environ.get('CUDA_HELPER',
                             '/usr/local/cuda/samples/common/inc')

cxx_flags = [
    '-I{}'.format(CUDA_HELPER)
]
if os.environ.get('USE_NCCL', '0') == '1':
    cxx_flags.append('-DMOE_USE_NCCL')

setup(
    name='moe_cuda',
    ext_modules=[
        CUDAExtension(
            name='moe_cuda',
            sources=[
                'cuda/moe.cpp',
                'cuda/cuda_stream_manager.cpp',
                'cuda/moe_cuda_kernel.cu',
            ],
            extra_compile_args={
                'cxx': cxx_flags,
                'nvcc': cxx_flags,
            })
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
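
# Building the extension (a sketch, assuming it is run from the repository
# root; pass CUDA_HELPER/USE_NCCL through the environment as above):
#     python3 setup.py build      # artifacts land in build/lib.linux-x86_64-*,
#                                 # which the launcher script below adds to
#                                 # PYTHONPATH
#     python3 setup.py install    # or install into site-packages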
#!/bin/bash
# When launched by Open MPI, bind each local rank to its own GPU.
if [ ! -z "$OMPI_COMM_WORLD_LOCAL_RANK" ]
then
    export CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_LOCAL_RANK
fi
export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
mkdir -p logs
if [ -z "$1" ]
then
    python3 moe_test.py 2>logs/$OMPI_COMM_WORLD_RANK.log
else
    python3 "$@" 2>logs/$OMPI_COMM_WORLD_RANK.log
fi
import math
import sys
import time

import torch
from torch import nn

from moe import FMoE, BruteForceMoE

dev_name_default = 'cuda:0'


def perf():
    torch.manual_seed(42 + torch.distributed.get_rank())
    torch.cuda.manual_seed(42 + torch.distributed.get_rank())

    if len(sys.argv) == 6:
        batch_size = int(sys.argv[2])
        in_feat = int(sys.argv[3])
        out_feat = int(sys.argv[4])
        num_expert = int(sys.argv[5])
    else:
        batch_size = 4096
        in_feat = 1024
        out_feat = 4096
        num_expert = 4
    if torch.distributed.get_rank() == 0:
        print('Performance test case bs {} {}x{} ne {}'.format(
            batch_size, in_feat, out_feat, num_expert))

    if torch.distributed.get_world_size() > 1:
        dev_name = 'cuda'
    else:
        dev_name = dev_name_default

    inp = torch.rand(batch_size, in_feat).cuda(dev_name)
    gate = torch.randint(low=0,
                         high=num_expert * torch.distributed.get_world_size(),
                         size=(batch_size, ),
                         requires_grad=False).int().cuda(dev_name)
    # world_size is the module-level global set in __main__.
    moe = FMoE(num_expert, in_feat, out_feat, world_size).cuda(dev_name)
    moe.train()

    # Warm-up iterations so kernel compilation and allocator start-up costs
    # are not timed.
    for _ in range(4):
        o = moe(inp, gate)

    n_runs = 16
    tott = 0.
    backt = 0.
    maxt = 0.
    sqtot = 0.
    for i in range(n_runs):
        gate = torch.randint(low=0,
                             high=num_expert * torch.distributed.get_world_size(),
                             size=(batch_size, ),
                             requires_grad=False).int().cuda(dev_name)
        ts = time.time()
        o = moe(inp, gate)
        te = time.time()

        loss = o.sum()

        bts = time.time()
        loss.backward()
        bte = time.time()

        tott += te - ts
        sqtot += (te - ts)**2
        maxt = max(maxt, te - ts)
        backt += bte - bts  # accumulate; the mean is taken over n_runs below

    gflops = 2e-9 * n_runs * in_feat * out_feat * batch_size / tott
    print('Time mean/max/stdev/back {:.3f} {:.3f} {:.3f} {:.3f} ms, '
          '{:.3f} GFLOPs'.format(
              tott * 1e3 / n_runs, maxt * 1e3,
              math.sqrt(max(sqtot / n_runs - (tott / n_runs)**2, 0.)) * 1e3,
              backt * 1e3 / n_runs, gflops))


def test_module(moe, linear, inp, gate):
    linear.zero_grad()
    moe.zero_grad()
    x = linear(inp)
    output = moe(x, gate)
    y = output.mean()
    y.backward()
    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad


def test():
    torch.manual_seed(42 + torch.distributed.get_rank())
    torch.cuda.manual_seed(42 + torch.distributed.get_rank())

    batch_size = 4
    num_expert = 2
    in_feat = 6
    out_feat = 7

    linear = nn.Linear(in_feat, in_feat).cuda()
    if world_size > 1:
        moe = FMoE(num_expert, in_feat, out_feat, world_size).cuda()
    else:
        moe = FMoE(num_expert, in_feat, out_feat).cuda()
    moe_raw = BruteForceMoE(num_expert, in_feat, out_feat, world_size).cuda()

    # Give the brute-force reference the same weights as the fast module.
    if world_size == 1:
        moe_raw.weight.data = moe.weight.data.clone()
    else:
        # Each rank holds only its local experts; gather all of them.
        weight_array = [torch.empty_like(moe.weight.data).cpu()
                        for _ in range(world_size)]
        torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
        moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0,
                         high=num_expert * world_size,
                         size=(batch_size,),
                         requires_grad=False).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
    if world_size > 1:
        # The reference holds gradients for all experts; reduce them across
        # ranks and keep only this rank's slice for comparison.
        rank = torch.distributed.get_rank()
        ou, wg, lwg, lbg = raw_out
        wg = wg.cpu()
        torch.distributed.all_reduce(wg)
        wg = wg[rank * num_expert:(rank + 1) * num_expert]
        raw_out = ou, wg.cuda(), lwg, lbg
    for name, mo, ro in zip(names, moe_out, raw_out):
        err = (mo - ro).abs().sum()
        print('{} abs err {}'.format(name, err))


def test_dp():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    batch_size = 6
    num_expert = 4
    in_feat = 2
    out_feat = 3

    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ),
                         requires_grad=False).int().cuda()

    print("data parallel of an nn.Linear model")
    linear = nn.Linear(in_feat, in_feat).cuda()
    linear_dp = torch.nn.DataParallel(linear, device_ids=[0, 1, 2])
    output = linear_dp(inp)
    print("successful!")

    print("data parallel of our MoE model")
    moe = FMoE(num_expert, in_feat, out_feat).cuda()
    moe_dp = torch.nn.DataParallel(moe, device_ids=[0, 1, 2])
    for i in range(5):
        output = moe_dp(inp, gate)


if __name__ == '__main__':
    torch.distributed.init_process_group(backend='mpi')
    world_size = torch.distributed.get_world_size()
    if len(sys.argv) >= 2:
        task = sys.argv[1]
        print('Specified task {}'.format(task))
        if task == 'correctness':
            test()
        elif task == 'dp':
            test_dp()
        elif task == 'performance':
            perf()
    else:
        test()
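
# Example launches (assumptions: PyTorch built with MPI support, the extension
# built, and the shell launcher above saved as run_test.sh -- the file name is
# hypothetical):
#
#     mpirun -np 1 bash run_test.sh moe_test.py correctness
#     mpirun -np 2 bash run_test.sh moe_test.py performance 4096 1024 4096 4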