Unverified commit a6a8c4a7, authored by Rick Ho, committed by GitHub

Merge pull request #103 from laekov/faster-expert-shadow

FasterMoE Expert Shadowing
parents 0c308313 91a5e794
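For context, the faster schedule and the new expert-shadowing path introduced by this PR are driven by environment variables that appear throughout the diff below. A minimal, hypothetical launch-side sketch (the variable names are taken from this PR; the values are only examples, and the schedule switch is read at import time of fmoe.layers, so it has to be set before fmoe is imported):

import os

os.environ["FMOE_FASTER_SCHEDULE_ENABLE"] = "1"            # turn on the smart schedule
os.environ["FMOE_FASTER_SHADOW_ENABLE"] = "ON"             # allow experts to be shadowed
os.environ["FMOE_FASTER_GROUP_SIZE"] = "2"                 # pipeline granularity read by the C++ scheduler
os.environ["FMOE_FASTER_GLBPLC_NETBW"] = str(50 * 1e9 / 8) # network bandwidth (bytes/s) used by the shadow policy
os.environ["FMOE_FASTER_GLBPLC_GPUTP"] = str(11.5e12)      # GeMM throughput (FLOPs) used by the shadow policy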
@@ -19,14 +19,63 @@ void setSmartSchEnabled(int s) {
    smart_sch_enabled = s;
}
+inline ncclDataType_t getNcclDataType(at::ScalarType t) {
+    switch (t) {
+        case at::kChar: return ncclInt8;
+        case at::kByte: return ncclUint8;
+        case at::kFloat: return ncclFloat;
+        case at::kDouble: return ncclDouble;
+        case at::kInt: return ncclInt32;
+        case at::kLong: return ncclInt64;
+        case at::kHalf: return ncclHalf;
+        case at::kBool: return ncclUint8;
+#if defined(ENABLE_NCCL_BF16_DATATYPE)
+        case at::kBFloat16: return ncclBfloat16;
+#endif
+        default: return ncclChar;
+    }
+}
+void _reduce_grad(
+        torch::Tensor t,
+        long root,
+        long expert_size) {
+    auto smgr = getCudaStreamManager(t.device().index());
+    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
+    cudaEvent_t evt_stash;
+    cudaEventCreate(&evt_stash);
+    cudaEventRecord(evt_stash, torch_stream);
+    cudaStreamWaitEvent(smgr->stream(0), evt_stash, 0);
+    cudaEventDestroy(evt_stash);
+    auto dtype = getNcclDataType(t.scalar_type());
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(),
+            "fmoe_cuda_reduce_grad", ([&] {
+        void* buf = (void*)t.data_ptr<scalar_t>();
+        NCCL_SAFE_CALL(ncclReduce(buf, buf, expert_size,
+                dtype,
+                ncclSum, root,
+                smgr->ncclcomm, smgr->stream(0)));
+    })
+    );
+}
std::vector<torch::Tensor> _smart_sch_forward(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long global_batch_size,
+        long expert_size,
        long n_workers,
-        py::function forward_fn) {
+        py::function forward_fn,
+        py::function get_param_fn,
+        py::function stash_fn,
+        py::function pop_fn) {
    if (pipeline_gran == -1) {
        char* p = getenv("FMOE_FASTER_GROUP_SIZE");
        if (p) {
@@ -47,14 +96,28 @@ std::vector<torch::Tensor> _smart_sch_forward(
    // TODO: maybe empty is faster
    auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
+    std::vector<torch::Tensor> params;
+    auto stored_models_ = stored_models.data_ptr<bool>();
+    for (long i = 0; i < num_expert * n_workers; ++i) {
+        if (stored_models_[i]) {
+            torch::Tensor t = input_buf.new_empty({expert_size});
+            if (i / num_expert == rank) {
+                get_param_fn(t);
+            }
+            params.push_back(t);
+        }
+    }
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_cuda_smart_sch_forward", ([&] {
        fmoe_cuda_fused_forward_impl(
            forward_fn,
+            stash_fn,
+            pop_fn,
            input_buf.device(),
+            params,
            input_buf.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
@@ -64,7 +127,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
-            d_model, num_expert, rank, n_workers,
+            d_model, num_expert, rank, n_workers, expert_size,
            pipeline_gran, smgr);
    }));
    return {output_buf, global_input_buf};
@@ -78,7 +141,11 @@ torch::Tensor _smart_sch_backward(
        long buf_batch_size,
        long global_batch_size,
        long n_workers,
-        py::function backward_fn) {
+        py::function backward_fn,
+        py::function stash_fn,
+        py::function pop_fn,
+        py::function collect_fn,
+        py::function set_grad_fn) {
    const auto num_expert = local_expert_count.size(0) / n_workers;
    auto smgr = getCudaStreamManager(grad_out.device().index());
    int rank;
@@ -88,10 +155,14 @@ torch::Tensor _smart_sch_backward(
    auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
    auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.scalar_type(),
            "fmoe_cuda_smartsch_backward", ([&] {
        fmoe_cuda_fused_backward_impl(
            backward_fn,
+            stash_fn,
+            pop_fn,
+            collect_fn,
+            set_grad_fn,
            grad_out.device(),
            grad_out.data_ptr<scalar_t>(),
@@ -105,7 +176,7 @@ torch::Tensor _smart_sch_backward(
            d_model, num_expert, rank, n_workers,
            pipeline_gran, smgr);
    }));
-    return {grad_in,};
+    return grad_in;
}
#endif
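The new _reduce_grad entry point above is exported to Python as reduce_grad (see the PYBIND11 section later in this diff) and is used during the backward pass to sum a shadowed expert's flattened gradient onto its owner rank over FastMoE's communication stream. A hedged sketch of the Python-side call, mirroring the collect_fn defined later in this PR (grad, root and expert_size are placeholders):

import fmoe_cuda as fmoe_native

# grad: flat 1-D CUDA tensor of length expert_size holding this rank's gradient
# for one shadowed expert; root: the rank that owns the expert and receives the sum.
fmoe_native.reduce_grad(grad, root, expert_size)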
@@ -13,7 +13,7 @@
template<typename scalar_t>
-void _exchange_with(
+void exchangeWith(
        const scalar_t* sendbuf, size_t sendcount, int t_send,
        scalar_t* recvbuf, size_t recvcount, int t_recv,
        long d_model,
@@ -39,15 +39,16 @@ void _exchange_with(
    int gidx_recv = ei * world_size + rank_recv; \
    int idx_self = ei + rank * num_expert;
-void _compute_ptrs(long num_expert, long rank, long world_size,
+void computePtrs(long num_expert, long rank, long world_size,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        int *local_ptr,
        int *global_ptr,
        int *local_global_ptr) {
    local_ptr[0] = global_ptr[0] = local_global_ptr[0] = 0;
    for (int i = 0; i < num_expert * world_size; ++i) {
        local_ptr[i + 1] = local_ptr[i] + local_expert_count[i];
@@ -73,10 +74,11 @@ void _compute_ptrs(long num_expert, long rank, long world_size,
    }
}
template<typename scalar_t>
-void _compute_fn(py::function fn, c10::Device device,
+void computeFn(py::function fn, c10::Device device,
        scalar_t* inp_buf, scalar_t* out_buf,
-        int ei, long step, long offset, long micro_batch_size, long d_model,
+        long idx, long offset, long micro_batch_size, long d_model,
        CudaStreamManager* smgr) {
    auto options = torch::TensorOptions()
        .dtype(c10::CppTypeToScalarType<scalar_t>::value)
@@ -87,7 +89,7 @@ void _compute_fn(py::function fn, c10::Device device,
    auto oup = torch::from_blob(out_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
    smgr->use_default = true;
-    fn(inp, oup, step);
+    fn(inp, oup, idx);
    smgr->use_default = false;
}
@@ -95,25 +97,29 @@
template<typename scalar_t>
void fmoe_cuda_fused_forward_impl(
        py::function forward_fn,
+        py::function stash_fn,
+        py::function pop_fn,
        c10::Device device,
+        std::vector<torch::Tensor> params,
-        const scalar_t* input_buf,
+        scalar_t* input_buf,
        scalar_t* global_input_buf,
        scalar_t* global_output_buf,
        scalar_t* output_buf,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        long d_model,
-        long num_expert, long rank, long world_size,
+        long num_expert, long rank, long world_size, long expert_size,
        long pipeline_gran, CudaStreamManager* smgr) {
+    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
    int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
-    _compute_ptrs(num_expert, rank, world_size,
+    computePtrs(num_expert, rank, world_size,
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);
@@ -130,15 +136,16 @@ void fmoe_cuda_fused_forward_impl(
        cudaEventCreate(output_ready + i);
    }
+    // S_0 ... S_n
    for (long step = 0; step < n_groups; ++step) {
-        for (int ei = 0; ei < num_expert; ++ei) {
+        for (long ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
-                _exchange_with(input_buf + local_ptr[idx_send] * d_model,
+                exchangeWith(input_buf + local_ptr[idx_send] * d_model,
                    local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                    global_input_buf + global_ptr[gidx_recv] * d_model,
                    global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
@@ -149,22 +156,59 @@ void fmoe_cuda_fused_forward_impl(
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }
+    // Broadcast shadowed experts
+    cudaEvent_t evt_get, *evt_shadow;
+    if (params.size() > 0) {
+        evt_shadow = new cudaEvent_t[params.size()];
+    }
+    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
+        if (stored_models[i]) {
+            if (i / num_expert == rank) {
+                cudaEventCreate(&evt_get);
+                cudaEventRecord(evt_get, torch_stream);
+                cudaStreamWaitEvent(smgr->stream(1), evt_get);
+                cudaEventDestroy(evt_get);
+            }
+            NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
+                    expert_size * sizeof(scalar_t), ncclChar,
+                    i / num_expert, smgr->ncclcomm, smgr->stream(0)));
+            cudaEventCreate(evt_shadow + si);
+            cudaEventRecord(evt_shadow[si], smgr->stream(0));
+            ++si;
+        }
+    }
+    // C_0 ... C_n
    for (long step = 0; step < n_groups; ++step) {
-        cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
+        cudaStreamWaitEvent(torch_stream, input_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
            long micro_batch_size = global_ptr[ei * world_size +
                    (from_base + pipeline_gran)] - offset;
-            _compute_fn(forward_fn, device,
+            computeFn(forward_fn, device,
                    global_input_buf, global_output_buf,
-                    ei, step, offset, micro_batch_size, d_model, smgr);
+                    step, offset, micro_batch_size, d_model, smgr);
        }
-        auto stream = c10::cuda::getCurrentCUDAStream().stream();
-        cudaEventRecord(output_ready[step], stream);
+        cudaEventRecord(output_ready[step], torch_stream);
    }
+    // Compute over shadowed experts
+    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
+        if (stored_models[i]) {
+            stash_fn(params[si], si);
+            cudaStreamWaitEvent(torch_stream, evt_shadow[si], 0);
+            long offset = local_ptr[i];
+            long micro_batch_size = local_expert_count[i];
+            computeFn(forward_fn, device,
+                    input_buf, output_buf,
+                    n_groups + si, offset, micro_batch_size, d_model, smgr);
+            ++si;
+        }
+    }
+    pop_fn();
+    // R_0 ... R_n
    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
@@ -174,7 +218,7 @@ void fmoe_cuda_fused_forward_impl(
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
-                _exchange_with(global_output_buf + global_ptr[gidx_send] * d_model,
+                exchangeWith(global_output_buf + global_ptr[gidx_send] * d_model,
                    global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                    output_buf + local_ptr[idx_recv] * d_model,
                    local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
@@ -184,31 +228,6 @@ void fmoe_cuda_fused_forward_impl(
        }
    }
-    /* TODO: Shadowing support
-    int offset = global_ptr[world_size * num_expert];
-    for (int j = 0; j < world_size; j++) {
-        for (int i = 0; i < num_expert; i++) {
-            int idx = j * num_expert + i;
-            if (!stored_models[idx])
-                continue;
-            weight1 = params[j][0][0].data_ptr<scalar_t>();
-            weight2 = params[j][0][last].data_ptr<scalar_t>();
-            auto stream = 2 + (idx % (SMGR_N_STREAMS - 2));
-            _compute_mlp_forward(
-                input_buf + local_ptr[idx] * d_model, weight1, weight2,
-                middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf + local_ptr[idx] * d_model,
-                i,
-                0, local_expert_count[idx],
-                d_model, d_hidden,
-                smgr->stream(stream), smgr->handle(stream));
-        }
-    }*/
    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;
@@ -217,6 +236,9 @@ void fmoe_cuda_fused_forward_impl(
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
+    for (unsigned i = 0; i < params.size(); ++i) {
+        cudaEventDestroy(evt_shadow[i]);
+    }
    delete [] input_ready;
    delete [] output_ready;
}
@@ -225,6 +247,10 @@ void fmoe_cuda_fused_forward_impl(
template<typename scalar_t>
void fmoe_cuda_fused_backward_impl(
        py::function backward_fn,
+        py::function stash_fn,
+        py::function pop_fn,
+        py::function collect_fn,
+        py::function set_grad_fn,
        c10::Device device,
        scalar_t* grad_out,
@@ -232,21 +258,21 @@ void fmoe_cuda_fused_backward_impl(
        scalar_t* global_grad_in,
        scalar_t* grad_in,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        long d_model,
        long num_expert, long rank, long world_size,
        long pipeline_gran, CudaStreamManager* smgr) {
+    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
    int *local_ptr = new int[num_expert * world_size + 1];
    int *global_ptr = new int[num_expert * world_size + 1];
    int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
-    _compute_ptrs(num_expert, rank, world_size,
+    computePtrs(num_expert, rank, world_size,
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);
    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
@@ -260,6 +286,7 @@ void fmoe_cuda_fused_backward_impl(
        cudaEventCreate(output_ready + i);
    }
+    // S_0 ... S_n
    for (long step = 0; step < n_groups; ++step) {
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
@@ -268,7 +295,7 @@ void fmoe_cuda_fused_backward_impl(
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
-                _exchange_with(grad_out + local_ptr[idx_send] * d_model,
+                exchangeWith(grad_out + local_ptr[idx_send] * d_model,
                    local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                    global_grad_out + global_ptr[gidx_recv] * d_model,
                    global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
@@ -279,21 +306,54 @@ void fmoe_cuda_fused_backward_impl(
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }
+    // Shadowed experts backward and reduce
+    cudaEvent_t *evt_reduce = new cudaEvent_t[num_expert];
+    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
+        if (stored_models[i]) {
+            stash_fn(si);
+            long offset = local_ptr[i];
+            long micro_batch_size = local_expert_count[i];
+            computeFn(backward_fn, device,
+                    grad_out, grad_in,
+                    n_groups + si, offset, micro_batch_size, d_model, smgr);
+            collect_fn(si, i / num_expert);
+            if (i / num_expert == rank) {
+                cudaEventCreate(evt_reduce + i % num_expert);
+                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
+            }
+            ++si;
+        }
+    }
+    pop_fn();
+    // C_0 ... C_n
    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
            long micro_batch_size = global_ptr[ei * world_size +
                    (from_base + pipeline_gran)] - offset;
-            _compute_fn(backward_fn, device,
+            computeFn(backward_fn, device,
                    global_grad_out, global_grad_in,
-                    ei, step, offset, micro_batch_size, d_model, smgr);
+                    step, offset, micro_batch_size, d_model, smgr);
        }
-        // TODO: get pytorch's compute stream
+        cudaEventRecord(output_ready[step], torch_stream);
    }
+    // Collect gradients for shadowed experts
+    for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
+        if (stored_models[i]) {
+            if (i / num_expert == rank) {
+                cudaStreamWaitEvent(torch_stream, evt_reduce[i % num_expert], 0);
+                set_grad_fn(si);
+            }
+            ++si;
+        }
+    }
+    // R_0 ... R_n
    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
@@ -303,7 +363,7 @@ void fmoe_cuda_fused_backward_impl(
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
-                _exchange_with(global_grad_in + global_ptr[gidx_send] * d_model,
+                exchangeWith(global_grad_in + global_ptr[gidx_send] * d_model,
                    global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                    grad_in + local_ptr[idx_recv] * d_model,
                    local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
@@ -315,36 +375,6 @@ void fmoe_cuda_fused_backward_impl(
    checkCudaErrors(cudaGetLastError());
-    /* TODO: Shadowing support
-    int offset = global_ptr[world_size * num_expert];
-    for (int j = 0; j < world_size; j++) {
-        for (int i = 0; i < num_expert; i++) {
-            int idx = j * num_expert + i;
-            if (!stored_models[idx])
-                continue;
-            weight1 = params[j][0][0].data_ptr<scalar_t>();
-            weight2 = params[j][0][last].data_ptr<scalar_t>();
-            grad_weight1 = params[j][0][0].mutable_grad().data_ptr<scalar_t>();
-            grad_weight2 = params[j][0][last].mutable_grad().data_ptr<scalar_t>();
-            auto stream = 2 + (idx % (SMGR_N_STREAMS - 2));
-            _compute_mlp_backward(
-                original_input_buf + local_ptr[idx] * d_model, weight1, weight2,
-                middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf, grad_out + local_ptr[idx] * d_model,
-                grad_middle + (offset + local_global_ptr[idx]) * d_hidden, grad_weight1, grad_weight2, grad_in + local_ptr[idx] * d_model,
-                i,
-                0, local_expert_count[idx],
-                d_model, d_hidden, 0, // we never consider it to be the first since it's already initialized to zero and we are lazy
-                smgr->stream(stream), smgr->handle(stream));
-        }
-    }
-    */
    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;
@@ -355,6 +385,12 @@ void fmoe_cuda_fused_backward_impl(
    }
    delete [] input_ready;
    delete [] output_ready;
+    for (long i = 0; i < num_expert; ++i) {
+        if (stored_models[i + rank * num_expert]) {
+            cudaEventDestroy(evt_reduce[i]);
+        }
+    }
+    delete [] evt_reduce;
}
#endif  // SMART_SCHEDULE_H
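A note on the loops above: the send (S), compute (C) and receive (R) phases are pipelined over groups of ranks, with the group size taken from FMOE_FASTER_GROUP_SIZE. As a rough illustration (assuming n_groups = world_size // pipeline_gran, which is how the group-size knob is exercised by the test at the end of this PR):

# Illustration only: smaller groups mean more, finer-grained S/C/R stages.
world_size = 8
for pipeline_gran in (1, 2, 4):          # group sizes used by the new test below
    n_groups = world_size // pipeline_gran
    print(pipeline_gran, "->", n_groups, "pipeline stages")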
@@ -63,8 +63,13 @@ std::vector<torch::Tensor> _smart_sch_forward(
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
-        long global_batch_size, long n_workers,
-        py::function forward_fn);
+        long global_batch_size,
+        long expert_size,
+        long n_workers,
+        py::function forward_fn,
+        py::function get_param_fn,
+        py::function stash_fn,
+        py::function pop_fn);
torch::Tensor _smart_sch_backward(
        torch::Tensor grad_out,
        torch::Tensor local_expert_count,
@@ -73,7 +78,15 @@ torch::Tensor _smart_sch_backward(
        long buf_batch_size,
        long global_batch_size,
        long n_workers,
-        py::function backward_fn);
+        py::function backward_fn,
+        py::function stash_fn,
+        py::function pop_fn,
+        py::function collect_fn,
+        py::function set_grad_fn);
+void _reduce_grad(
+        torch::Tensor t,
+        long root,
+        long expert_size);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
@@ -85,6 +98,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("smart_sch_forward", &_smart_sch_forward, "E2E MoE layer forward with smart scheduling");
    m.def("smart_sch_backward", &_smart_sch_backward, "E2E MoE layer backward with smart scheduling");
+    m.def("reduce_grad", &_reduce_grad, "Reduce gradients over FastMoE's communication stream");
#endif
    m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
...
@@ -3,21 +3,34 @@
#include <cassert>
#include <thread>
#include <iostream>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+#include "fastermoe/status.h"
#include "stream_manager.h"
#define SMGR_N_STREAMS 16
cudaStream_t CudaStreamManager::stream(size_t idx) {
+    if (this->use_default) {
+        return c10::cuda::getCurrentCUDAStream().stream();
+    }
    return this->streams[idx % SMGR_N_STREAMS];
}
cublasHandle_t CudaStreamManager::handle(size_t idx) {
+    if (this->use_default) {
+        return at::cuda::getCurrentCUDABlasHandle();
+    }
    return this->handles[idx % SMGR_N_STREAMS];
}
void CudaStreamManager::sync(int idx) {
+    if (this->use_default) {
+        return;
+    }
    for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
        cudaStreamSynchronize(streams[i]);
    }
...
@@ -21,13 +21,14 @@ public:
    int device;
    cublasHandle_t* handles;
    cudaStream_t* streams;
+    bool use_default;
#ifdef FMOE_USE_NCCL
    char ncclgood;
    ncclComm_t ncclcomm;
#endif
public:
-    CudaStreamManager(int device_): device(device_) {
+    CudaStreamManager(int device_): device(device_), use_default(false) {
        this->setup(device);
    }
...
import os


def float_from_env(key, default=-1):
    if key in os.environ:
        return float(os.environ[key])
    return default


def switch_from_env(key, default=False):
    if key in os.environ:
        return os.environ[key] in ['1', 'ON']
    return default
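These two helpers are how the rest of this PR reads its knobs; for example, mirroring the defaults used by the shadow policy further down:

# Example use of the helpers above (defaults taken from the shadow policy below).
bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8)
shadowing_on = switch_from_env('FMOE_FASTER_SHADOW_ENABLE')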
import torch


def get_expert_param_size(e):
    return sum(map(lambda x: x.numel(), e.parameters()))


def get_expert_params(e, out):
    offset = 0
    for n, p in e.named_parameters():
        seg = out[offset:offset + p.numel()]
        offset += p.numel()
        seg.copy_(p.data.flatten())


def stash_expert_params(e, params):
    if not hasattr(e, 'expert_param_stash'):
        setattr(e, 'expert_param_stash', dict())
    offset = 0
    for n, p in e.named_parameters():
        if n not in e.expert_param_stash:
            e.expert_param_stash[n] = p.data.clone()
        with torch.no_grad():
            seg = params[offset:offset + p.numel()]
            offset += p.numel()
            p.copy_(seg.reshape(p.shape))


def pop_expert_params(e):
    if not hasattr(e, 'expert_param_stash'):
        return
    for n, p in e.named_parameters():
        with torch.no_grad():
            p.copy_(e.expert_param_stash[n])
    e.expert_param_stash.clear()


def collect_expert_grads(e, grads):
    offset = 0
    for _, p in e.named_parameters():
        seg = grads[offset:offset + p.numel()]
        offset += p.numel()
        if p.grad is not None:
            seg.copy_(p.grad.flatten())
            p.grad = None
        else:
            seg.zero_()


def set_grads(e, grads):
    offset = 0
    for n, p in e.named_parameters():
        seg = grads[offset:offset + p.numel()]
        offset += p.numel()
        if p.grad is None:
            p.grad = seg.clone().reshape(p.shape)
        else:
            p.grad += seg.reshape(p.shape)
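Taken together, these helpers implement the parameter round trip used by the shadowing code: flatten a local expert into a 1-D buffer, temporarily swap in a (possibly remote) copy, then restore the original weights and fold collected gradients back in. A small self-contained sketch, treating a plain nn.Linear as the expert purely for illustration:

import torch

expert = torch.nn.Linear(16, 16)
size = get_expert_param_size(expert)      # total number of scalars in the expert
flat = torch.empty(size)
get_expert_params(expert, flat)           # flat now holds weight and bias, flattened

stash_expert_params(expert, flat)         # load shadowed params, keep the originals
# ... run the expert with the shadowed parameters ...
pop_expert_params(expert)                 # restore the stashed originals

grads = torch.zeros(size)
collect_expert_grads(expert, grads)       # copy .grad into the flat buffer (zeros if None)
set_grads(expert, grads)                  # accumulate the flat buffer back into .grad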
@@ -7,6 +7,9 @@ from torch.autograd.function import Function
from fmoe.functions import prepare_forward, ensure_comm
from fmoe.functions import _local_scatter, _local_gather
import fmoe_cuda as fmoe_native
+from fmoe.fastermoe import expert_utils
+from .shadow_policy import get_shadow_policy
class MoEForward(Function):
@@ -14,6 +17,7 @@ class MoEForward(Function):
    def forward(
            ctx,
            expert_fn,
+            experts,
            inp, # models,
            pos_s, pos_g,
            local_expert_count, global_expert_count,
@@ -22,31 +26,43 @@ class MoEForward(Function):
            world_size):
        local_input_buf = _local_scatter(inp, pos_s)
-        # TODO: leave this for furture work of expert shadowing
-        # model_params = [[tuple(m.parameters()) for m in node] for node in models]
-        ctx.gibs = [None] * world_size
-        ctx.gobs = [None] * world_size
+        ctx.gibs = [None] * (world_size * 2)
+        ctx.gobs = [None] * (world_size * 2)
        def _expert_forward(x, y, idx):
+            nothing = lambda a: a
            x = x.data
            with torch.enable_grad():
                x.requires_grad = True
-                y0 = expert_fn(x, [x.shape[0]])
+                # To skip torch autograd's version check.
+                with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
+                    y0 = expert_fn(x, [x.shape[0]])
            ctx.gibs[idx] = x
            ctx.gobs[idx] = y0
            y.copy_(y0)
+        ctx.experts = experts
+        if stored_models.any():
+            ctx.expert_size = expert_utils.get_expert_param_size(experts)
+        else:
+            ctx.expert_size = 0
+        get_param_fn = lambda out: expert_utils.get_expert_params(experts, out)
+        pop_fn = lambda: expert_utils.pop_expert_params(experts)
+        ctx.shadows = [None] * world_size
+        def stash_fn(params, idx):
+            expert_utils.stash_expert_params(experts, params)
+            ctx.shadows[idx] = params
        local_output_buf, gib = fmoe_native.smart_sch_forward(
                local_input_buf,
                local_expert_count, global_expert_count,
-                stored_models, fwd_batch_size,
-                world_size, _expert_forward)
+                stored_models, fwd_batch_size, ctx.expert_size,
+                world_size, _expert_forward, get_param_fn, stash_fn, pop_fn)
        out = _local_gather(local_output_buf, pos_g, out_batch_size,
                maybe_overlap=False)
        variables = (pos_s, pos_g, local_expert_count, global_expert_count,
-                stored_models, gib)
+                stored_models, gib, local_input_buf)
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
        ctx.save_for_backward(*variables)
@@ -56,28 +72,42 @@ class MoEForward(Function):
    @staticmethod
    def backward(ctx, grad_out):
        (pos_s, pos_g, local_expert_count, global_expert_count,
-                stored_models, _) = ctx.saved_tensors
+                stored_models, _1, _2) = ctx.saved_tensors
        (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
        def _expert_backward(grad_y, grad_x, idx):
            y = ctx.gobs[idx]
-            torch.autograd.backward([y], [grad_y])
            x = ctx.gibs[idx]
+            torch.autograd.backward([y], [grad_y])
            grad_x.copy_(x.grad)
+        experts = ctx.experts
+        def stash_fn(idx):
+            expert_utils.stash_expert_params(experts, ctx.shadows[idx])
+        pop_fn = lambda: expert_utils.pop_expert_params(experts)
+        def collect_fn(idx, root):
+            grad = ctx.shadows[idx]
+            expert_utils.collect_expert_grads(experts, grad)
+            fmoe_native.reduce_grad(grad, root, ctx.expert_size)
+        set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx])
        grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
        grad_in_buf = fmoe_native.smart_sch_backward(
                grad_out_buf,
                local_expert_count, global_expert_count,
                stored_models,
                pos_s.shape[0], fwd_batch_size,
-                world_size, _expert_backward)
+                world_size,
+                _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)
        grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)
-        return (None, grad_in, None, None, None, None, None, None, None, None)
+        return (None, None, grad_in, None, None, None, None, None, None, None, None)
-def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
+policy_fn = None
+def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
    # TODO: Using multiple tensors as input is to be supported.
    assert(isinstance(inp, torch.Tensor))
    # TODO: Support many experts on each process
@@ -90,15 +120,20 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
        fwd_batch_size,
    ) = prepare_forward(gate, n_expert, world_size)
-    # TODO: Expert shadowing is to be supported. Currently using all 0s
-    stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)
+    global policy_fn
+    if policy_fn is None:
+        policy_fn = get_shadow_policy(d_model=inp.shape[-1])
+    if stored_models is None:
+        stored_models = policy_fn(local_expert_count, global_expert_count,
+                n_expert, world_size)
    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]
    out_batch_size = inp.shape[0] * topk
-    return MoEForward.apply(expert_fn, inp,
+    return MoEForward.apply(expert_fn, experts, inp,
            torch.div(pos, topk, rounding_mode='floor'), pos,
            local_expert_count, global_expert_count, stored_models,
            fwd_batch_size, out_batch_size, world_size)
import os
import torch
import torch.distributed as dist

from .config import float_from_env, switch_from_env
from fmoe.functions import get_moe_group


def global_policy(local_expert_count, _gec, num_expert, world_size):
    r"""
    This is the policy for two-layer MLPs, using the formula in the PPoPP paper.
    A few parameters are used in this policy.
    * `d_model`: feature length of the MLP input and output.
    * `alpha`: the ratio of the MLP's hidden size to `d_model`.
    * `bw_net`: bandwidth of the network (GBps)
    * `bw_mm`: computation throughput of performing GeMM (FLOPs)
    """
    bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8)
    bw_mm = float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12)
    alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2)
    d_model = float_from_env('FMOE_FASTER_GLBPLC_DMODEL', 2048)

    moe_group = get_moe_group()
    local_expert_count = local_expert_count.cuda()
    agecs = [torch.empty_like(local_expert_count) for _ in range(world_size)]
    dist.all_gather(agecs, local_expert_count, group=moe_group)
    all_global_expert_count = torch.stack(agecs)

    # TODO: data type other than float
    data_size = 4

    fwd_expert_counts = all_global_expert_count.sum(1).cpu()
    B_ws, indices = fwd_expert_counts.flatten().sort(0, descending=True)

    alphaH2 = alpha * (d_model ** 2)
    B_w = B_ws[0]

    comm = float('+inf')
    send_feature_time = d_model * data_size / bw_net
    send_model_time = 2 * alphaH2 * data_size / bw_net
    comp_time = 4 * alphaH2 / bw_mm

    lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w
    res = torch.zeros(world_size * num_expert, dtype=torch.bool)

    shadow_time = 0
    for i, index in enumerate(indices):
        if i + 1 == indices.numel():
            break
        B_k = B_ws[i + 1]
        shadow_time += send_model_time
        lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + shadow_time
        if lat_new < lat_base:
            lat_base = lat_new
            res[index] = True
        else:
            break
    return res
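As a rough worked example of the rule above (the batch sizes are made up; the constants are the defaults read from the environment): shadowing the hottest expert pays send_model_time once, but shrinks the bottleneck batch from B_w to the next-largest B_k.

# Worked example of the latency comparison in global_policy (illustrative numbers).
d_model, alpha, data_size = 2048, 2, 4
bw_net, bw_mm = 50 * 1e9 / 8, 11.5e12
alphaH2 = alpha * d_model ** 2
send_feature_time = d_model * data_size / bw_net
send_model_time = 2 * alphaH2 * data_size / bw_net
comp_time = 4 * alphaH2 / bw_mm
B_w, B_k = 40000, 10000        # hypothetical largest and second-largest expert loads
lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w
lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + send_model_time
print(lat_new < lat_base)      # True here, so the hottest expert would be shadowed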
def no_shadow_policy(_lec, _gec, num_expert, world_size):
    res = torch.zeros(world_size * num_expert, dtype=bool)
    return res


def get_shadow_policy(d_model=None):
    if d_model is not None and 'FMOE_FASTER_GLBPLC_DMODEL' not in os.environ:
        os.environ['FMOE_FASTER_GLBPLC_DMODEL'] = str(d_model)
    if not switch_from_env('FMOE_FASTER_SHADOW_ENABLE'):
        return no_shadow_policy
    return global_policy
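A sketch of how the policy is consumed, mirroring the call site in the schedule code earlier in this diff (the expert-count tensors are placeholders, and global_policy itself must run inside an initialized distributed job):

policy_fn = get_shadow_policy(d_model=1024)
stored_models = policy_fn(local_expert_count, global_expert_count,
                          n_expert, world_size)
# stored_models[i] is True if expert i should be broadcast (shadowed) this iteration.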
@@ -10,12 +10,21 @@ import fmoe_cuda
from .utils import get_torch_default_comm
+_moe_group = None
def ensure_comm(t, comm):
    if comm is None:
        comm = get_torch_default_comm()
+    global _moe_group
+    _moe_group = comm
    fmoe_cuda.ensure_nccl(comm, t)
+def get_moe_group():
+    return _moe_group
def count_by_gate(gate, num_expert, world_size, require_pos=True):
    with torch.no_grad():
        local_expert_count = torch.zeros(
...
@@ -11,6 +11,8 @@ from .functions import MOEScatter, MOEGather
from .functions import AllGather, Slice
from .gates import NaiveGate
+from .fastermoe.config import switch_from_env
def mark_module_parallel_comm(module, comm):
    r"""
@@ -21,7 +23,7 @@ def mark_module_parallel_comm(module, comm):
        setattr(p, "dp_comm", comm)
-def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size, **kwargs):
    r"""
    A private function that performs the following steps to complete the MoE
    computation.
@@ -76,7 +78,9 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    return outp
-if os.environ.get('FMOE_FASTER_SCHEDULE_ENABLE', '0') in ['1', 'ON']:
+fmoe_faster_schedule = False
+if switch_from_env('FMOE_FASTER_SCHEDULE_ENABLE', False):
+    fmoe_faster_schedule = True
    from .fastermoe.schedule import _fmoe_general_global_forward
@@ -227,7 +231,9 @@ class FMoE(nn.Module):
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]
        fwd = _fmoe_general_global_forward(
-            moe_inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
+            moe_inp, gate_top_k_idx, self.expert_fn,
+            self.num_expert, self.world_size,
+            experts=self.experts
        )
        # recover deleted tensors
...
@@ -3,6 +3,7 @@ import random
import os
import sys
from typing import Dict
+import random
import pytest
import torch
@@ -19,7 +20,6 @@ def _ensure_initialized():
    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
-    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
...
import pytest
import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.functions import ensure_comm
from test_ddp import _ensure_initialized, _run_distributed
from test_numerical import _assert_numerical
from fmoe.fastermoe.schedule import _fmoe_general_global_forward as smart_fwd
from fmoe.layers import _fmoe_general_global_forward as naive_fwd


@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16, 512])
@pytest.mark.parametrize("n_expert", [1])
@pytest.mark.parametrize("group_sz", [1, 2, 4])
@pytest.mark.parametrize("pass_stored", [False, True])
def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz, pass_stored):
    _run_distributed('_test_faster_shadow',
            n_process,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'pass_stored': pass_stored
            },
            script=__file__,
            env=dict(
                FMOE_FASTER_GROUP_SIZE=str(group_sz),
                FMOE_FASTER_SHADOW_ENABLE='ON'
            )
    )


def _test_faster_shadow(d_model, batch_size, n_expert, pass_stored):
    _ensure_initialized()
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    x1 = torch.rand(batch_size, d_model).cuda()
    x1.requires_grad = True
    x2 = x1.data.clone()
    x2.requires_grad = True
    topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda()

    m1 = torch.nn.Linear(d_model, d_model).cuda()
    m2 = torch.nn.Linear(d_model, d_model).cuda()
    with torch.no_grad():
        m2.weight.copy_(m1.weight)
        m2.bias.copy_(m1.bias)

    def ef1(x, fec):
        y = m1(x)
        return y

    def ef2(x, fec):
        y = m2(x)
        return y

    if pass_stored:
        stored_models = torch.randint(0, 2, (world_size,)).bool().cuda()
        dist.broadcast(stored_models, 0)
        stored_models = stored_models.cpu()
        # if rank == 0:
        #     print('stored models {}'.format(stored_models))

    ensure_comm(x1, None)
    if pass_stored:
        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1,
                stored_models=stored_models)
    else:
        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1)
    y1.sum().backward()

    y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2)
    y2.sum().backward()

    _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
            [y1, x1.grad, m1.bias.grad, m1.weight.grad],
            [y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        # test_faster_shadow(8, 16, 16, 1, 2)
        _test_faster_shadow(1024, 16, 1, True)