Commit 5680c599 authored by Rich Ho

Merge branch 'master' into laekov/gate

parents 90c4bccf 3c42c892
@@ -23,20 +23,24 @@ void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
 // local_exchange
 void _assign_pos(
         torch::Tensor cum_count,
         torch::Tensor gate,
         torch::Tensor pos);
 
 // parallel_linear
 std::vector<torch::Tensor> _linear_forward(
         torch::Tensor input_buf,
+        torch::Tensor expert_count,
         torch::Tensor weight,
-        torch::Tensor expert_count);
+        at::optional<torch::Tensor> bias
+        );
 
 std::vector<torch::Tensor> _linear_backward(
         torch::Tensor grad_output_buf,
         torch::Tensor input_buf,
-        torch::Tensor weight,
-        torch::Tensor expert_count);
+        torch::Tensor expert_count,
+        torch::Tensor weight,
+        at::optional<torch::Tensor> bias
+        );
 
 // balancing
 std::vector<torch::Tensor> _limit_by_capacity(
#include "parallel_linear.h" #include "parallel_linear.cuh"
#include "utils/fmoe_utils.h" #include "utils/fmoe_utils.h"
#include <torch/extension.h> #include <torch/extension.h>
std::vector<torch::Tensor> _linear_forward( std::vector<torch::Tensor> _linear_forward(
torch::Tensor input_buf, torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight, torch::Tensor weight,
torch::Tensor expert_count at::optional<torch::Tensor> bias
) { ) {
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
auto smgr = getCudaStreamManager(input_buf.device().index()); auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0); const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0); const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1); const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2); const auto in_feat = weight.size(2);
#ifdef FMOE_DEBUG #ifdef MOE_DEBUG
printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
num_expert, in_feat, out_feat); num_expert, in_feat, out_feat);
#endif #endif
auto out_options = torch::TensorOptions()
.device(input_buf.device()) torch::Tensor output;
.dtype(input_buf.dtype());
auto output = torch::empty({batch_size, out_feat}, out_options); if (bias.has_value()) {
output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "fmoe_linear_forward", } else{
auto out_options = torch::TensorOptions()
.device(input_buf.device())
.dtype(input_buf.dtype());
output = torch::empty({batch_size, out_feat}, out_options);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
([&] { ([&] {
fmoe_cuda_forward_impl<scalar_t>( fmoe_cuda_linear_forward_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(), input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(), expert_count.data_ptr<long>(),
output.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
bias.has_value(),
in_feat, in_feat,
out_feat, out_feat,
num_expert, num_expert,
smgr smgr
); );
})); }));
return {output, }; return {output, };
} }
std::vector<torch::Tensor> _linear_backward( std::vector<torch::Tensor> _linear_backward(
torch::Tensor grad_output_buf, // [batch_size x out_feat] torch::Tensor grad_output_buf,
torch::Tensor input_buf, // [batch_size x out_feat] torch::Tensor input_buf,
torch::Tensor weight, // [num_expert x out_feat x in_feat] torch::Tensor expert_count,
torch::Tensor expert_count torch::Tensor weight,
at::optional<torch::Tensor> bias
) { ) {
CHECK_INPUT(grad_output_buf);
CHECK_INPUT(input_buf);
CHECK_INPUT(weight);
auto smgr = getCudaStreamManager(input_buf.device().index()); auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0); const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0); const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1); const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2); const auto in_feat = weight.size(2);
#ifdef FMOE_DEBUG #ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, " printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
"out_feat (d_ffn)=%ld\n", "out_feat (d_ffn)=%ld\n",
batch_size, num_expert, in_feat, out_feat); batch_size, num_expert, in_feat, out_feat);
#endif #endif
auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat}); auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat}); auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "ffmoe_linear_backward", ([&] { AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
fmoe_cuda_backward_impl<scalar_t>( fmoe_cuda_linear_backward_impl<scalar_t>(
grad_output_buf.data_ptr<scalar_t>(), grad_output_buf.data_ptr<scalar_t>(),
input_buf.data_ptr<scalar_t>(), input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(), expert_count.data_ptr<long>(),
grad_input_buf.data_ptr<scalar_t>(), grad_input_buf.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(), grad_weight.data_ptr<scalar_t>(),
grad_bias.data_ptr<scalar_t>(),
bias.has_value(),
batch_size, batch_size,
in_feat, in_feat,
out_feat, out_feat,
...@@ -83,6 +90,6 @@ std::vector<torch::Tensor> _linear_backward( ...@@ -83,6 +90,6 @@ std::vector<torch::Tensor> _linear_backward(
); );
})); }));
return {grad_input_buf, grad_weight}; return {grad_input_buf, grad_weight, grad_bias};
} }
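
For reference, the bias handling above amounts to the following PyTorch sketch (the helper name ref_linear_forward and the shapes are illustrative, not part of the extension API): the output buffer is seeded with each expert's bias repeated over that expert's rows, and the per-expert GEMM then accumulates on top of it, which is why the CUDA wrapper below switches the GEMM beta to 1 when a bias is present.

import torch

def ref_linear_forward(input_buf, expert_count, weight, bias=None):
    # input_buf: [batch, in_feat], rows already grouped by expert
    # expert_count: [num_expert] row counts; weight: [num_expert, out_feat, in_feat]
    if bias is not None:
        # seed the output with the per-token bias, as the binding does via repeat_interleave
        output = bias.repeat_interleave(expert_count.to(bias.device), dim=0)
    else:
        output = input_buf.new_zeros(input_buf.shape[0], weight.shape[1])
    ptr = 0
    for i, cnt in enumerate(expert_count.tolist()):
        if cnt == 0:
            continue
        # the expert's GEMM accumulates onto the seeded rows (beta = 1)
        output[ptr:ptr + cnt] += input_buf[ptr:ptr + cnt] @ weight[i].t()
        ptr += cnt
    return output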
@@ -2,17 +2,68 @@
 #include "utils/cublas_wrapper.h"
 
+/*
+    This function is to be called with one block per each column
+*/
+template <typename scalar_t>
+__global__
+void column_reduce(const scalar_t * matrix, scalar_t * result,
+        int m /* lines */, int n /* columns*/) {
+    // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
+    extern __shared__ unsigned char my_smem[];
+    scalar_t *sdata = reinterpret_cast<scalar_t *>(my_smem);
+
+    // normal tid
+    int tid = threadIdx.x + threadIdx.y * blockDim.x;
+    // transposed tid for shared memory
+    int new_tid = threadIdx.y + threadIdx.x * blockDim.y;
+
+    // true x value in the matrix
+    int real_x = threadIdx.x + blockDim.x * blockIdx.x;
+
+    int i = real_x + n * threadIdx.y;
+    const int it = n * blockDim.y;
+    int offset = it;
+
+    float accumulator = 0;
+    if (threadIdx.y < m && real_x < n) {
+        // store all the values from this column in a warped way
+        accumulator = matrix[i];
+        while (i + offset < n * m) {
+            accumulator += matrix[i + offset];
+            offset += it;
+        }
+    }
+
+    // save column reduction data in a transposed way
+    sdata[new_tid] = accumulator;
+    __syncthreads();
+
+    for (size_t t = 16; t > 0; t >>= 1) {
+        if (tid < 32 * 32 - 16)
+            sdata[tid] += sdata[tid + t];
+        __syncthreads();
+    }
+
+    if (threadIdx.y == 0 && real_x < n)
+        result[real_x] = sdata[new_tid];
+}
+
 template <typename scalar_t>
-void fmoe_cuda_forward_impl(
+void fmoe_cuda_linear_forward_impl(
         const scalar_t* input_buf,
         const scalar_t* weight,
         const long* expert_count,
         scalar_t* output_buf,
+        const bool has_bias,
         const size_t in_feat,
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
-    scalar_t alpha = 1, beta = 0;
+    scalar_t alpha = 1, beta = has_bias ? 1 : 0;
 
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
@@ -37,13 +88,15 @@ void fmoe_cuda_forward_impl(
 }
 
 template <typename scalar_t>
-void fmoe_cuda_backward_impl(
+void fmoe_cuda_linear_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
+       scalar_t* grad_bias,
+       const bool has_bias,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
@@ -51,10 +104,16 @@ void fmoe_cuda_backward_impl(
         CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = 0;
 
+    // bias
+    dim3 block_threads(32, 32);
+    dim3 grid_threads(out_feat / 32 + (out_feat % 32 ? 1 : 0), 1);
+
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                     sizeof(scalar_t) * in_feat * out_feat);
+            cudaMemset(grad_bias + i * out_feat, 0, sizeof(scalar_t) * out_feat);
             continue;
         }
         // Use T(B) x T(A) = T(C) to produce row-major C
@@ -84,8 +143,20 @@ void fmoe_cuda_backward_impl(
                 &beta,
                 grad_weight + i * in_feat * out_feat, in_feat
                 ));
 
+        if (has_bias) {
+            column_reduce
+                <<<grid_threads, block_threads, sizeof(scalar_t) * 1024, smgr->stream(0)>>>
+                (
+                    grad_output_buf + ptr * out_feat,
+                    grad_bias + i * out_feat,
+                    expert_count[i],
+                    out_feat
+                );
+        }
+
         ptr += expert_count[i];
     }
     smgr->sync(num_expert);
 }
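
The column_reduce launch above computes, for each expert with a nonzero count, the column-wise sum of that expert's [expert_count[i], out_feat] block of grad_output, i.e. the gradient of that expert's bias. A minimal PyTorch reference of the same reduction (the helper name ref_bias_grad is illustrative):

import torch

def ref_bias_grad(grad_output_buf, expert_count, out_feat):
    # grad_output_buf: [batch, out_feat], rows grouped by expert
    num_expert = expert_count.numel()
    grad_bias = grad_output_buf.new_zeros(num_expert, out_feat)
    ptr = 0
    for i, cnt in enumerate(expert_count.tolist()):
        if cnt > 0:
            # column reduction over this expert's block of rows
            grad_bias[i] = grad_output_buf[ptr:ptr + cnt].sum(dim=0)
        ptr += cnt
    return grad_bias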
@@ -147,21 +147,25 @@ class MOELinear(Function):
     """
 
     @staticmethod
-    def forward(ctx, global_input_buf, weight, fwd_expert_count):
+    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
         (global_output_buf,) = fmoe_cuda.linear_forward(
-            global_input_buf, weight, fwd_expert_count
+            global_input_buf, fwd_expert_count, weight, bias
         )
-        variables = (global_input_buf, weight, fwd_expert_count)
+        variables = (global_input_buf, fwd_expert_count, weight, bias)
         ctx.save_for_backward(*variables)
         return global_output_buf
 
     @staticmethod
     def backward(ctx, grad_out):
-        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = fmoe_cuda.linear_backward(
-            grad_out, input_buf, weight, fwd_expert_count
+        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
+        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
+            grad_out, input_buf, fwd_expert_count, weight, bias
         )
-        return grad_inp_buf, grad_weight, None
+
+        if not torch.is_tensor(bias):
+            grad_bias = None
+
+        return grad_inp_buf, None, grad_weight, grad_bias
 
 
 class MOEGather(Function):
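
A hedged usage sketch of the reordered autograd op (assuming the extension is built and CUDA is available; the shapes are illustrative, fwd_expert_count is the per-expert row count kept as a CPU LongTensor, and the input rows are already grouped by expert):

import torch
from fmoe.functions import MOELinear

num_expert, in_feat, out_feat = 2, 8, 16
fwd_expert_count = torch.tensor([3, 5], dtype=torch.long)  # 3 + 5 = 8 rows in total
inp = torch.rand(8, in_feat, device="cuda", requires_grad=True)
weight = torch.rand(num_expert, out_feat, in_feat, device="cuda", requires_grad=True)
bias = torch.rand(num_expert, out_feat, device="cuda", requires_grad=True)

out = MOELinear.apply(inp, fwd_expert_count, weight, bias)  # [8, out_feat]
out.sum().backward()  # fills inp.grad, weight.grad and bias.grad; passing bias=None skips the bias path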
@@ -41,37 +41,7 @@ class FMoELinear(nn.Module):
         r"""
         Call MOE function
         """
-        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
-        if self.bias is not None:
-            # TODO: torch.repeat_interleave seems have numerical
-            #       instability in backward, leading to incorrect
-            #       gradient computation for solution 1 and 2.
-            #       Solution 3 uses a for-loop to expand the bias,
-            #       but is 50% slower.
-            #       This part should finally goes to MOELinear.apply,
-            #       like MOELinear.apply(x, weight, bias, count)
-
-            # Solution 1
-            bias = torch.repeat_interleave(
-                self.bias, fwd_expert_count.to(self.bias.device), dim=0
-            )
-
-            # Solution 2
-            # bias_idx = torch.arange(self.num_expert)\
-            #         .repeat_interleave(fwd_expert_count)
-            # bias = self.bias[bias_idx]
-
-            # Solution 3
-            # bias = []
-            # for i in range(self.num_expert):
-            #     if fwd_expert_count[i] > 0:
-            #         bias.append(
-            #             self.bias[i].unsqueeze(0).expand(
-            #                 fwd_expert_count[i], -1
-            #             )
-            #         )
-            # bias = torch.cat(bias, dim=0)
-
-            x = x + bias
+        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
         return x
 
     def extra_repr(self) -> str:
@@ -29,7 +29,7 @@ if __name__ == '__main__':
                 'cuda/local_exchange.cu',
                 'cuda/balancing.cu',
                 'cuda/global_exchange.cpp',
-                'cuda/parallel_linear.cpp',
+                'cuda/parallel_linear.cu',
                 'cuda/fmoe_cuda.cpp',
             ],
             extra_compile_args={
@@ -11,7 +11,7 @@ from test_numerical import test_fmoe_linear as _test_fmoe_linear
 from test_numerical import _test_fmoe_local_ddp
 
 
-def _run_distributed(func, world_size, args: Dict):
+def _run_distributed(func, world_size, args: Dict, script=__file__):
     if torch.cuda.device_count() < world_size:
         pytest.skip("No enough GPU")
     import subprocess
@@ -25,7 +25,7 @@ def _run_distributed(func, world_size, args: Dict):
     for i in range(world_size):
         os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
         p = subprocess.Popen(
-            [sys.executable, __file__, func, json.dumps(args)], stdout=subprocess.PIPE
+            [sys.executable, script, func, json.dumps(args)], stdout=subprocess.PIPE
         )
         ps.append(p)
 
@@ -41,8 +41,9 @@ def _run_distributed(func, world_size, args: Dict):
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("d_hidden", [32])
 @pytest.mark.parametrize("mp_size", [1, 2])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
 def test_fmoe_linear_distributed(
-    num_expert, top_k, batch_size, d_model, d_hidden, mp_size
+    num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
 ):
     _run_distributed(
         "_test_fmoe_linear",
@@ -54,6 +55,7 @@ def test_fmoe_linear_distributed(
             "d_model": d_model,
             "d_hidden": d_hidden,
             "mp_size": mp_size,
+            "data_type": data_type
         },
     )
 
@@ -120,5 +122,6 @@ if __name__ == "__main__":
     else:
         test_fmoe_local_ddp(mp_size=2)
         test_fmoe_linear_distributed(
-            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
+            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2,
+            data_type="torch.HalfTensor"
         )
 import pytest
 import os
+import sys
+import json
 import math
 import torch
 import torch.distributed as dist
 
 from fmoe.gates import GShardGate, SwitchGate
+from test_ddp import _run_distributed
 
 
 def _ensure_initialized():
@@ -16,14 +21,27 @@ def _ensure_initialized():
         dist.init_process_group(backend="nccl")
 
 
-@pytest.mark.parametrize("d_model", [8, 1024])
-@pytest.mark.parametrize("batch_size", [16, 4096])
-@pytest.mark.parametrize("n_expert", [1, 4, 16])
-@pytest.mark.parametrize("cap", [.1, .5, 1.1])
+@pytest.mark.parametrize("d_model", [1024])
+@pytest.mark.parametrize("batch_size", [16])
+@pytest.mark.parametrize("n_expert", [1, 4])
+@pytest.mark.parametrize("cap", [.1, 1.1])
 def test_gshard_gate(d_model, batch_size, n_expert, cap):
-    _ensure_initialized()
-    if dist.get_world_size() * n_expert < 2:
+    if 1 * n_expert < 2:
         pytest.skip("No enough experts")
+    _run_distributed('_test_gshard_gate',
+            1,
+            {
+                'd_model': d_model,
+                'batch_size': batch_size,
+                'n_expert': n_expert,
+                'cap': cap
+            },
+            script=__file__
+            )
+
+
+def _test_gshard_gate(d_model, batch_size, n_expert, cap):
+    _ensure_initialized()
     gate = GShardGate(d_model, n_expert, dist.get_world_size(),
             capacity=(cap, cap)).cuda()
     x = torch.rand(batch_size, d_model).cuda()
@@ -37,11 +55,24 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
         assert(i <= real_cap)
 
 
-@pytest.mark.parametrize("d_model", [8, 1024])
-@pytest.mark.parametrize("batch_size", [16, 4096])
-@pytest.mark.parametrize("n_expert", [1, 4, 16])
-@pytest.mark.parametrize("cap", [.1, .5, 1.1])
+@pytest.mark.parametrize("d_model", [1024])
+@pytest.mark.parametrize("batch_size", [4096])
+@pytest.mark.parametrize("n_expert", [1, 16])
+@pytest.mark.parametrize("cap", [.1, .8])
 def test_switch_gate(d_model, batch_size, n_expert, cap):
+    _run_distributed('_test_switch_gate',
+            1,
+            {
+                'd_model': d_model,
+                'batch_size': batch_size,
+                'n_expert': n_expert,
+                'cap': cap
+            },
+            script=__file__
+            )
+
+
+def _test_switch_gate(d_model, batch_size, n_expert, cap):
     _ensure_initialized()
     gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
             capacity=(cap, cap)).cuda()
@@ -57,6 +88,11 @@ def test_switch_gate(d_model, batch_size, n_expert, cap):
 
 
 if __name__ == '__main__':
-    _ensure_initialized()
-    test_gshard_gate(4096, 1024, 4, .2)
-    # test_switch_gate(4096, 1024, 4, .2)
+    if len(sys.argv) >= 3:
+        args = json.loads(sys.argv[2])
+        locals()[sys.argv[1]](**args)
+    else:
+        _ensure_initialized()
+        # test_gshard_gate(4096, 1024, 4, .2)
+        test_gshard_gate(8, 16, 1, .1)
+        # test_switch_gate(4096, 1024, 4, .2)
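
The gate tests above reuse the _run_distributed launcher from test_ddp with script=__file__, so each rank is spawned as a child process that re-enters this file and dispatches to the worker named on the command line via locals()[sys.argv[1]](**args). Roughly what the launcher ends up executing for one rank (values illustrative):

import json, os, subprocess, sys

args = {"d_model": 1024, "batch_size": 16, "n_expert": 1, "cap": 0.1}
os.environ["OMPI_COMM_WORLD_RANK"] = "0"  # rank picked up when the child initializes its process group
subprocess.Popen(
    [sys.executable, "test_gates.py", "_test_gshard_gate", json.dumps(args)],
    stdout=subprocess.PIPE,
)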
@@ -10,7 +10,7 @@ import numpy as np
 from copy import deepcopy
 
 from fmoe.functions import MOEGather, MOEScatter, count_by_gate
-from test_numerical import _assert_numercial
+from test_numerical import _assert_numerical
 
 
 @pytest.mark.parametrize("n_expert", [1, 4, 8])
 @pytest.mark.parametrize("topk", [1, 2])
@@ -30,10 +30,10 @@ def test_scatter(n_expert, topk, batch_size, d_model, world_size):
     inp_raw = inp.data.clone()
     out_raw = torch.empty(pos.shape[0], d_model,
             device=inp.device, dtype=inp.dtype)
-    out_raw.sum().backward()
+    # out_raw.sum().backward()
     for i, f in enumerate(pos.cpu()):
         out_raw[i] = inp[f % batch_size]
-    _assert_numercial(['out'], [out], [out_raw], 0)
+    _assert_numerical(['out'], [out], [out_raw], 0)
     # TODO: check grad
 
 
 if __name__ == '__main__':
@@ -17,11 +17,13 @@ from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
 
 
 def _perform_forward(
-    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
+    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group, data_type='torch.FloatTensor'
 ):
     moe.zero_grad()
     moe_raw.zero_grad()
-    inp = torch.rand(batch_size, d_model).cuda()
+
+    inp = torch.rand(batch_size, d_model).type(data_type).cuda()
+
     if mp_group is not None:
         group_sender = rank // mp_group.size() * mp_group.size()
         torch.distributed.broadcast(inp, group_sender, group=mp_group)
@@ -46,15 +48,17 @@ def _perform_forward(
     return moe_out, raw_out, inp.grad, inp_raw.grad
 
 
-def _assert_numercial(names, moe_out_list, raw_out_list, rank):
+def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
     for name, mo, ro in zip(names, moe_out_list, raw_out_list):
-        err = (mo - ro).abs().sum()
+        err = (mo - ro).abs().max()
         print("Rank {} {} abs err {}".format(rank, name, err))
-        if err > 1e-3:
+        if err > precision:
             sys.stderr.write(f"=========== {name} moe out ==============\n")
             sys.stderr.write("{}\n".format(mo))
             sys.stderr.write(f"=========== {name} raw out ==============\n")
             sys.stderr.write("{}\n".format(ro))
+            sys.stderr.write(f"=========== {name} diff ==============\n")
+            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
             assert False
 
 
@@ -87,6 +91,7 @@ class MyMoE(FMoE):
 @pytest.mark.parametrize("mp_group", [None])
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
 def test_fmoe_linear(
     num_expert,
     top_k,
@@ -98,6 +103,7 @@ def test_fmoe_linear(
     mp_group,
     dp_group,
     world_group,
+    data_type,
     activation=torch.nn.functional.gelu,
 ):
     torch.manual_seed(42 + rank)
@@ -105,7 +111,7 @@ def test_fmoe_linear(
     moe = MyMoE(
         num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
-    ).cuda()
+    ).type(data_type).cuda()
 
     moe_raw = BruteForceMoELinear(
         activation=activation,
@@ -114,7 +120,7 @@ def test_fmoe_linear(
         d_hidden=d_hidden,
         world_size=world_size,
         top_k=top_k,
-    ).cuda()
+    ).type(data_type).cuda()
 
     if world_size == 1:
         moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
@@ -145,7 +151,7 @@ def test_fmoe_linear(
         moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)
 
     moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
-        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
     )
 
     moe_out_list = (
@@ -195,7 +201,10 @@ def test_fmoe_linear(
         "h4toh bias grad",
     ]
-    _assert_numercial(names, moe_out_list, raw_out_list, rank)
+
+    precision = 5e-1 if data_type == 'torch.HalfTensor' else 1e-3
+    _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=precision)
 
 
 @pytest.mark.parametrize("batch_size", [4])
@@ -296,7 +305,7 @@ def test_fmoe(
     raw_out_list = [raw_out, raw_grad, raw_grad_in]
     names = ["forward", "backward", "grad_in"]
-    _assert_numercial(names, moe_out_list, raw_out_list, rank)
+    _assert_numerical(names, moe_out_list, raw_out_list, rank)
 
 
 class MyModule(nn.Module):
@@ -372,7 +381,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
     names = ["mp grad", "dp grad", "wp grad"]
-    _assert_numercial(names, ddp_out_list, raw_out_list, rank)
+    _assert_numerical(names, ddp_out_list, raw_out_list, rank)
 
 
 if __name__ == "__main__":
+import sys
 import torch
 
 from fmoe.layers import _fmoe_general_global_forward
 from fmoe import FMoETransformerMLP
+from test_ddp import _run_distributed
 
 
 class ConstantGate(torch.nn.Module):
     def __init__(self, d_model, num_expert, world_size, top_k=1):
@@ -16,12 +19,34 @@ class ConstantGate(torch.nn.Module):
 
 
 def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
+    _run_distributed('_test_zero_fwd',
+            1,
+            {
+                'num_expert': num_expert,
+                'batch_size': batch_size,
+                'd_hidden': d_hidden
+            },
+            script=__file__
+            )
+
+
+def _test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
     inp = torch.rand(batch_size, d_hidden).cuda()
     gate = torch.zeros(batch_size, dtype=torch.int64).cuda()
     x = _fmoe_general_global_forward(inp, gate, lambda x, y: x, num_expert,
             world_size)
 
 
+def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
+    _run_distributed('_test_zero_transformer',
+            1,
+            {
+                'num_expert': num_expert,
+                'batch_size': batch_size,
+                'd_hidden': d_hidden
+            },
+            script=__file__
+            )
+
+
 def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
     inp = torch.rand(batch_size, d_hidden).cuda()
     model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
@@ -30,9 +55,13 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
 
 
 if __name__ == '__main__':
-    torch.distributed.init_process_group(backend="nccl")
-    torch.cuda.set_device(torch.distributed.get_rank())
-    # test_zero_fwd(world_size=torch.distributed.get_world_size())
-    test_zero_transformer(num_expert=16, batch_size=4096, d_hidden=1024,
-            world_size=torch.distributed.get_world_size())
-    print('done')
+    if len(sys.argv) >= 3:
+        args = json.loads(sys.argv[2])
+        torch.distributed.init_process_group(backend="nccl")
+        args['world_size'] = torch.distributed.get_world_size()
+        locals()[sys.argv[1]](**args)
+    else:
+        # test_zero_fwd(world_size=torch.distributed.get_world_size())
+        test_zero_transformer(num_expert=16, batch_size=4096, d_hidden=1024,
+                world_size=1)
+        print('done')