"git@developer.sourcefind.cn:norm/vllm.git" did not exist on "a490aafa3671da1b6b2be6cff4568913fcb1732c"
Commit 46c3722d authored by Rick Ho

forward pass test

parent 49b5b5d6
#ifdef FMOE_USE_NCCL
#include <cstdlib>
#include <vector>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "smart_schedule.h"
long pipeline_gran = -1;
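// Entry point exposed to Python as smart_sch_forward: allocates the global
// exchange buffers and dispatches to the templated pipeline implementation.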
torch::Tensor _smart_sch_forward(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
torch::Tensor stored_models,
long global_batch_size,
long n_workers,
py::function forward_fn) {
if (pipeline_gran == -1) {
char* p = getenv("FMOE_FASTER_GROUP_SIZE");
if (p) {
pipeline_gran = atoi(p);
} else {
pipeline_gran = 4;
}
}
auto smgr = getCudaStreamManager(input_buf.device().index());
int rank;
NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));
const auto num_expert = local_expert_count.size(0) / n_workers;
const auto d_model = input_buf.size(1);
auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"fmoe_cuda_fused_forward", ([&] {
fmoe_cuda_fused_forward_impl(
forward_fn,
input_buf.device(),
input_buf.data_ptr<scalar_t>(),
global_input_buf.data_ptr<scalar_t>(),
global_output_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
stored_models.data_ptr<bool>(),
d_model, num_expert, rank, n_workers,
pipeline_gran, smgr);
}));
return output_buf;
}
/*
std::vector<torch::Tensor> _fused_backward(
torch::Tensor input_buf,
std::vector<std::vector<std::vector<torch::Tensor>>> params,
torch::Tensor middle_buf,
torch::Tensor output_buf,
torch::Tensor grad_out,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
torch::Tensor inp,
torch::Tensor stored_models,
long global_batch_size,
long buf_batch_size,
long n_workers, bool has_bias) {
const auto num_expert = local_expert_count.size(0) / n_workers;
auto smgr = getCudaStreamManager(input_buf.device().index());
int rank;
ncclCommUserRank(smgr->ncclcomm, &rank);
const auto d_hidden = params[rank][0][0].size(1);
const auto d_model = params[rank][0][0].size(2);
auto global_grad_out = input_buf.new_zeros({global_batch_size, d_model});
auto grad_middle = input_buf.new_zeros({global_batch_size, d_hidden});
auto global_grad_in = input_buf.new_zeros({global_batch_size, d_model});
auto grad_in = input_buf.new_zeros({buf_batch_size, d_model});
for (auto node : params)
for (auto expert : node)
for (int i = 0; i < expert.size(); i++) {
// create the respective gradient of each tensor
CHECK_INPUT(expert[i]);
if (expert[i].grad().defined()) {
CHECK_INPUT(expert[i].grad());
continue;
}
expert[i].mutable_grad() = input_buf.new_zeros(expert[i].sizes());
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"fmoe_cuda_fused_backward", ([&] {
fmoe_cuda_fused_backward_impl(
input_buf.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
params,
middle_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
grad_out.data_ptr<scalar_t>(),
global_grad_out.data_ptr<scalar_t>(),
global_grad_in.data_ptr<scalar_t>(),
grad_middle.data_ptr<scalar_t>(),
grad_in.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
stored_models.data_ptr<bool>(),
d_model, d_hidden, num_expert, rank, n_workers, has_bias,
pipeline_gran, smgr);
}));
return {grad_in,};
}
*/
#endif
#ifndef SMART_SCHEDULE_H
#define SMART_SCHEDULE_H
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include "../stream_manager.h"
template<typename scalar_t>
void _exchange_with(
const scalar_t* sendbuf, size_t sendcount, int t_send,
scalar_t* recvbuf, size_t recvcount, int t_recv,
long d_model,
cudaStream_t stream, ncclComm_t comm) {
if (sendcount) {
ncclSend(sendbuf, sendcount * d_model * sizeof(scalar_t),
ncclChar, t_send , comm, stream);
}
if (recvcount) {
ncclRecv(recvbuf, recvcount * d_model * sizeof(scalar_t),
ncclChar, t_recv, comm, stream);
}
}
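// GEN_BASE computes, for a given pipeline step, the first rank of the group
// we send to (to_base) and the group we receive from (from_base).
// GEN_IDX derives the per-expert indices into the worker-major count arrays
// (idx_*) and the expert-major global_ptr layout (gidx_*).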
#define GEN_BASE(_step) \
long to_base = (group_rank + _step) % n_groups * pipeline_gran; \
long from_base = (group_rank + n_groups - _step) % n_groups * pipeline_gran;
#define GEN_IDX \
int idx_send = ei + rank_send * num_expert; \
int idx_recv = ei + rank_recv * num_expert; \
int gidx_send = ei * world_size + rank_send; \
int gidx_recv = ei * world_size + rank_recv; \
int idx_self = ei + rank * num_expert;
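// Builds prefix-sum offset tables over the per-(worker, expert) token counts:
// local_ptr indexes the local input/output buffers, global_ptr indexes the
// buffers of tokens received from other workers (expert-major layout), and
// local_global_ptr tracks tokens kept local for fetched (shadowed) experts.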
void _compute_ptrs(long num_expert, long rank, long world_size,
const long* local_expert_count,
const long* global_expert_count,
const bool* stored_models,
int *local_ptr,
int *global_ptr,
int *local_global_ptr) {
local_ptr[0] = global_ptr[0] = local_global_ptr[0] = 0;
for (int i = 0; i < num_expert * world_size; ++i) {
local_ptr[i + 1] = local_ptr[i] + local_expert_count[i];
local_global_ptr[i + 1] = local_global_ptr[i];
// if model fetched, add local tokens
if (stored_models[i]){
local_global_ptr[i + 1] += local_expert_count[i];
}
auto expert_idx = i % num_expert;
auto worker_idx = i / num_expert;
auto gp_idx = expert_idx * world_size + worker_idx;
// if local model wasn't fetched, receive global tokens
if (stored_models[rank * num_expert + expert_idx]) {
global_ptr[gp_idx + 1] = 0;
} else {
global_ptr[gp_idx + 1] = global_expert_count[i];
}
}
global_ptr[0] = 0;
for (int i = 0; i < num_expert * world_size; ++i) {
global_ptr[i + 1] += global_ptr[i];
}
}
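// Wraps raw device pointers as autograd-enabled tensors via torch::from_blob
// and invokes the Python-side expert callback on one micro-batch.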
template<typename scalar_t>
void _compute_forward(py::function fn, c10::Device device,
scalar_t* inp_buf, scalar_t* out_buf,
int ei, long step, long offset, long micro_batch_size, long d_model) {
auto options = torch::TensorOptions()
.dtype(c10::CppTypeToScalarType<scalar_t>::value)
.device(device)
.requires_grad(true);
auto inp = torch::from_blob(inp_buf + offset * d_model,
{micro_batch_size, d_model}, options);
auto oup = torch::from_blob(out_buf + offset * d_model,
{micro_batch_size, d_model}, options);
fn(inp, oup, step);
}
template<typename scalar_t>
void _compute_backward(py::function fn,
const scalar_t* inp_buf, const scalar_t* out_buf,
scalar_t* grad_out_buf, scalar_t* grad_in_buf,
int ei, long offset, long micro_batch_size) {
// Placeholder for the expert backward computation; not implemented in this
// forward-pass test commit.
}
template<typename scalar_t>
void fmoe_cuda_fused_forward_impl(
py::function forward_fn,
c10::Device device,
const scalar_t* input_buf,
scalar_t* global_input_buf,
scalar_t* global_output_buf,
scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
const bool* stored_models,
long d_model,
long num_expert, long rank, long world_size,
long pipeline_gran, CudaStreamManager* smgr) {
int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1];
int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
_compute_ptrs(num_expert, rank, world_size,
local_expert_count, global_expert_count, stored_models,
local_ptr, global_ptr, local_global_ptr);
if (pipeline_gran > world_size) {
pipeline_gran = world_size;
}
long n_groups = world_size / pipeline_gran;
long group_rank = rank / pipeline_gran;
cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
for (long i = 0; i < n_groups; ++i) {
cudaEventCreate(input_ready + i);
cudaEventCreate(output_ready + i);
}
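// Stage 1: scatter. For each pipeline step, exchange input tokens with the
// corresponding peer group on communication stream 0, recording an event so
// the compute stage can start as soon as this step's tokens have arrived.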
for (long step = 0; step < n_groups; ++step) {
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < pipeline_gran; ++j) {
int rank_send = j + to_base;
int rank_recv = j + from_base;
GEN_IDX;
_exchange_with(input_buf + local_ptr[idx_send] * d_model,
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_input_buf + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
d_model, smgr->stream(0), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
cudaEventRecord(input_ready[step], smgr->stream(0));
}
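// Stage 2: expert computation. Wait for each step's tokens, then run the
// Python expert forward on the received micro-batches, overlapping with the
// communication of later steps.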
for (long step = 0; step < n_groups; ++step) {
cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base];
long micro_batch_size = global_ptr[ei * world_size +
(from_base + pipeline_gran)] - offset;
_compute_forward(forward_fn, device,
global_input_buf, global_output_buf,
ei, step, offset, micro_batch_size, d_model);
}
auto stream = c10::cuda::getCurrentCUDAStream().stream();
cudaEventRecord(output_ready[step], stream);
}
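// Stage 3: gather. Once a step's outputs are ready, send them back to the
// ranks that own the corresponding tokens.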
for (long step = 0; step < n_groups; ++step) {
cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < pipeline_gran; ++j) {
int rank_send = j + from_base;
int rank_recv = j + to_base;
GEN_IDX;
_exchange_with(global_output_buf + global_ptr[gidx_send] * d_model,
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
output_buf + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
d_model, smgr->stream(0), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
/* TODO: Shadowing support
int offset = global_ptr[world_size * num_expert];
for (int j = 0; j < world_size; j++) {
for (int i = 0; i < num_expert; i++) {
int idx = j * num_expert + i;
if (!stored_models[idx])
continue;
weight1 = params[j][0][0].data_ptr<scalar_t>();
weight2 = params[j][0][last].data_ptr<scalar_t>();
auto stream = 2 + (idx % (SMGR_N_STREAMS- 2));
_compute_mlp_forward(
input_buf + local_ptr[idx] * d_model, weight1, weight2,
middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf + local_ptr[idx] * d_model,
i,
0, local_expert_count[idx],
d_model, d_hidden,
smgr->stream(stream), smgr->handle(stream));
}
}*/
delete [] local_ptr;
delete [] global_ptr;
delete [] local_global_ptr;
checkCudaErrors(cudaGetLastError());
for (long i = 0; i < n_groups; ++i) {
cudaEventDestroy(input_ready[i]);
cudaEventDestroy(output_ready[i]);
}
delete [] input_ready;
delete [] output_ready;
}
template<typename scalar_t>
void fmoe_cuda_fused_backward_impl(
py::function backward_fn,
const scalar_t* input_buf,
const scalar_t* output_buf,
const scalar_t* grad_out,
scalar_t* global_grad_out,
scalar_t* global_grad_in,
scalar_t* grad_in,
const long* local_expert_count,
const long* global_expert_count,
const bool* stored_models,
long d_model, long d_hidden,
long num_expert, long rank, long world_size,
long pipeline_gran, CudaStreamManager* smgr) {
int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1];
int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
_compute_ptrs(num_expert, rank, world_size,
local_expert_count, global_expert_count, stored_models,
local_ptr, global_ptr, local_global_ptr);
if (pipeline_gran > world_size) {
pipeline_gran = world_size;
}
long n_groups = world_size / pipeline_gran;
long group_rank = rank / pipeline_gran;
cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
for (long i = 0; i < n_groups; ++i) {
cudaEventCreate(input_ready + i);
cudaEventCreate(output_ready + i);
}
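// The backward pass mirrors the forward pipeline: scatter grad_out, run the
// expert backward per micro-batch, then gather grad_in. The compute stage is
// still a stub in this commit (see _compute_backward above).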
for (long step = 0; step < n_groups; ++step) {
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < pipeline_gran; ++j) {
int rank_send = j + to_base;
int rank_recv = j + from_base;
GEN_IDX;
_exchange_with(grad_out + local_ptr[idx_send] * d_model,
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_grad_out + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
d_model, smgr->stream(0), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
cudaEventRecord(input_ready[step], smgr->stream(0));
}
for (long step = 0; step < n_groups; ++step) {
cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base];
long micro_batch_size = global_ptr[ei * world_size +
(from_base + pipeline_gran)] - offset;
_compute_backward(backward_fn,
input_buf, output_buf, global_grad_out,
global_grad_in,
ei,
offset, micro_batch_size);
}
// TODO: get pytorch's compute stream
}
for (long step = 0; step < n_groups; ++step) {
cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < pipeline_gran; ++j) {
int rank_send = j + from_base;
int rank_recv = j + to_base;
GEN_IDX;
_exchange_with(global_grad_in + global_ptr[gidx_send] * d_model,
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
grad_in + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
d_model, smgr->stream(0), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
checkCudaErrors(cudaGetLastError());
/* TODO: Shadowing support
int offset = global_ptr[world_size * num_expert];
for (int j = 0; j < world_size; j++) {
for (int i = 0; i < num_expert; i++) {
int idx = j * num_expert + i;
if (!stored_models[idx])
continue;
weight1 = params[j][0][0].data_ptr<scalar_t>();
weight2 = params[j][0][last].data_ptr<scalar_t>();
grad_weight1 = params[j][0][0].mutable_grad().data_ptr<scalar_t>();
grad_weight2 = params[j][0][last].mutable_grad().data_ptr<scalar_t>();
auto stream = 2 + (idx % (SMGR_N_STREAMS- 2));
_compute_mlp_backward(
original_input_buf + local_ptr[idx] * d_model, weight1, weight2,
middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf, grad_out + local_ptr[idx] * d_model,
grad_middle + (offset + local_global_ptr[idx]) * d_hidden, grad_weight1, grad_weight2, grad_in + local_ptr[idx] * d_model,
i,
0, local_expert_count[idx],
d_model, d_hidden, 0, // we never consider it to be the first since it's already initialized to zero and we are lazy
smgr->stream(stream), smgr->handle(stream));
}
}
*/
delete [] local_ptr;
delete [] global_ptr;
delete [] local_global_ptr;
checkCudaErrors(cudaGetLastError());
for (long i = 0; i < n_groups; ++i) {
cudaEventDestroy(input_ready[i]);
cudaEventDestroy(output_ready[i]);
}
delete [] input_ready;
delete [] output_ready;
}
#endif // SMART_SCHEDULE_H
#include <iostream>
#include <vector>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/extension.h>
// global_exchange
@@ -56,6 +57,15 @@ std::vector<torch::Tensor> _swipe_once(
torch::Tensor gate_idx, torch::Tensor capacity_tensor,
long n_expert, long n_worker, long bias);
// smart scheduling
torch::Tensor _smart_sch_forward(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
torch::Tensor stored_models,
long global_batch_size, long n_workers,
py::function forward_fn);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
m.def("expert_exchange", &_expert_exchange, "FastMoE expert exchange (CUDA)");
@@ -63,6 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
m.def("swipe_once", &_swipe_once, "SWIPE balance strategy(CUDA)");
m.def("smart_sch_forward", &_smart_sch_forward, "E2E MoE layer forward with smart scheduling");
#endif
m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#include "cublas_wrapper.h"
#ifdef FMOE_USE_NCCL
#include <nccl.h>
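// Fused forward path: for each expert, exchange its tokens with all workers
// via NCCL, run a single cuBLAS GEMM on the gathered tokens, and send the
// results back, using one stream per expert.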
template<typename scalar_t>
void moe_cuda_global_fused_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
scalar_t* global_input_buf,
scalar_t* global_output_buf,
scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
long in_feat, long out_feat,
long num_expert, long world_size,
CudaStreamManager* smgr) {
int ptr = 0;
int send_ptr = 0;
int recv_ptr = 0;
int *expert_ptr = new int[num_expert * world_size];
expert_ptr[0] = 0;
for (int i = 1; i < num_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
scalar_t alpha = 1, beta = 0;
for (int i = 0; i < num_expert; ++i) {
int expert_count = 0;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
global_input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
recv_ptr += global_expert_count[idx];
expert_count += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
checkCudaErrors(cublasXgemm(
smgr->handle(i),
CUBLAS_OP_T,
CUBLAS_OP_N,
out_feat, expert_count, in_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
global_input_buf + ptr * in_feat, in_feat,
&beta,
global_output_buf + out_feat * ptr, out_feat
));
ptr += expert_count;
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < world_size; ++j) {
int idx = i + j * num_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
global_output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(i)));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(num_expert);
}
std::vector<torch::Tensor> moe_cuda_global_fused_forward(
torch::Tensor input_buf,
torch::Tensor weight,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long global_batch_size, long local_batch_size, long n_workers) {
const auto num_expert = local_expert_count.size(0) / n_workers;
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
auto smgr = getCudaStreamManager(input_buf.device().index());
auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
auto output_buf = input_buf.new_empty({local_batch_size, out_feat});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"moe_cuda_global_fused_forward", ([&] {
moe_cuda_global_fused_forward_impl(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
global_input_buf.data_ptr<scalar_t>(),
global_output_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
in_feat, out_feat, num_expert, n_workers,
smgr);
}));
return {output_buf, global_input_buf};
}
#endif
r"""
The smart schedule proposed in FasterMoE.
"""
import torch
from torch.autograd.function import Function
from fmoe.functions import prepare_forward, ensure_comm
from fmoe.functions import _local_scatter, _local_gather
import fmoe_cuda as fmoe_native
class MoEForward(Function):
@staticmethod
def forward(
ctx,
expert_fn,
inp, # models,
pos_s, pos_g,
local_expert_count, global_expert_count,
stored_models,
fwd_batch_size, out_batch_size,
world_size):
local_input_buf = _local_scatter(inp, pos_s)
# TODO: leave this for future work of expert shadowing
# model_params = [[tuple(m.parameters()) for m in node] for node in models]
ctx.gibs = [None] * world_size
ctx.gobs = [None] * world_size
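# Each expert invocation below stashes its autograd-enabled input and output
# tensors here, indexed by the step index passed from the C++ scheduler, so
# backward can replay torch.autograd.backward per micro-batch.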
def _expert_forward(x, y, idx):
x = x.data
x.requires_grad = True
y0 = expert_fn(x, [x.shape[0]])
ctx.gibs[idx] = x
ctx.gobs[idx] = y0
y.copy_(y0)
local_output_buf = fmoe_native.smart_sch_forward(
local_input_buf,
local_expert_count, global_expert_count,
stored_models, fwd_batch_size,
world_size, _expert_forward)
out = _local_gather(local_output_buf, pos_g, out_batch_size,
maybe_overlap=False)
variables = (pos_s, pos_g, local_expert_count, global_expert_count,
stored_models)
ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
ctx.save_for_backward(*variables)
return out
@staticmethod
def backward(ctx, grad_out):
(pos_s, pos_g, local_expert_count, global_expert_count,
gib, gmb, gob, stored_models) = ctx.saved_tensors
(fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
def _expert_backward(grad, idx):
y = ctx.gobs[idx]
torch.autograd.backward([y], [grad])
x = ctx.gibs[idx]
return x.grad
grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
grad_in_buf = fmoe_native.smart_sch_backward(
gib, gmb, gob, grad_out_buf,
local_expert_count, global_expert_count,
stored_models,
fwd_batch_size, pos_s.shape[0],
world_size, _expert_backward)
grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
return (None, grad_in, None, None, None, None, None, None, None, None)
def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
# TODO: Using multiple tensors as input is to be supported.
assert(isinstance(inp, torch.Tensor))
# TODO: Support many experts on each process
assert(n_expert == 1)
(
pos,
local_expert_count,
global_expert_count,
fwd_expert_count,
fwd_batch_size,
) = prepare_forward(gate, n_expert, world_size)
# TODO: Expert shadowing is to be supported. Currently using all 0s
stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)
topk = 1
if len(gate.shape) == 2:
topk = gate.shape[1]
out_batch_size = inp.shape[0] * topk
return MoEForward.apply(expert_fn, inp,
torch.div(pos, topk, rounding_mode='floor'), pos,
local_expert_count, global_expert_count, stored_models,
fwd_batch_size, out_batch_size, world_size)
@@ -2,6 +2,7 @@ r"""
FMoE core layer
"""
import tree
import os
import torch
import torch.nn as nn
@@ -46,7 +47,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
def scatter_func(tensor):
return MOEScatter.apply(
tensor,
torch.div(pos, topk, rounding_mode='floor'),
local_expert_count,
global_expert_count,
fwd_batch_size,
@@ -75,6 +76,10 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
return outp
if os.environ.get('FMOE_FASTER_SCHEDULE_ENABLE', '0') in ['1', 'ON']:
from .fastermoe.schedule import _fmoe_general_global_forward
class FMoE(nn.Module):
r"""
A general moe implementation that supports an arbitrary module as the
@@ -149,10 +154,12 @@ class FMoE(nn.Module):
"""
if self.experts_fused:
return self.experts(inp, fwd_expert_count)
if isinstance(fwd_expert_count, torch.Tensor):
fwd_expert_count = fwd_expert_count.cpu().numpy()
outputs = []
base_idx = 0
for i in range(self.num_expert):
batch_size = fwd_expert_count[i]
inp_slice = inp[base_idx : base_idx + batch_size]
outputs.append(self.experts[i](inp_slice))
base_idx += batch_size
@@ -43,7 +43,7 @@ if __name__ == '__main__':
author_email='hja20@mails.tsinghua.edu.cn',
license='Apache-2',
url='https://github.com/laekov/fastmoe',
packages=['fmoe', 'fmoe.megatron', 'fmoe.gates', 'fmoe.fastermoe'],
ext_modules=[
CUDAExtension(
name='fmoe_cuda',
@@ -54,6 +54,7 @@ if __name__ == '__main__':
'cuda/global_exchange.cpp',
'cuda/parallel_linear.cu',
'cuda/fmoe_cuda.cpp',
'cuda/fastermoe/smart_schedule.cpp',
],
define_macros=define_macros,
extra_compile_args={
import pytest
import os
import sys
import json
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.functions import ensure_comm
from test_ddp import _ensure_initialized, _run_distributed
from test_numerical import _assert_numerical
from fmoe.fastermoe.schedule import _fmoe_general_global_forward as smart_fwd
from fmoe.layers import _fmoe_general_global_forward as naive_fwd
@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1])
def test_faster_schedule(n_process, d_model, batch_size, n_expert):
_run_distributed('_test_faster_schedule',
n_process,
{
'd_model': d_model,
'batch_size': batch_size,
'n_expert': n_expert
},
script=__file__,
env=dict()
)
def _test_faster_schedule(d_model, batch_size, n_expert):
_ensure_initialized()
rank = dist.get_rank()
world_size = dist.get_world_size()
x = torch.rand(batch_size, d_model).cuda()
x.requires_grad = True
topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda()
m = torch.nn.Linear(d_model, d_model).cuda()
def expert_fn(x, fec):
y = m(x)
return y
ensure_comm(x, None)
y = smart_fwd(x, topk_idx, expert_fn, n_expert, world_size)
z = naive_fwd(x, topk_idx, expert_fn, n_expert, world_size)
_assert_numerical(['out'], [y], [z], rank)
if __name__ == '__main__':
if len(sys.argv) >= 3:
args = json.loads(sys.argv[2])
locals()[sys.argv[1]](**args)
else:
# test_faster_schedule(8, 16, 16, 1)
_test_faster_schedule(4, 2, 1)