OpenDAS / FastMoE / Commits / 46c3722d

Commit 46c3722d authored Mar 29, 2022 by Rick Ho

forward pass test

Parent: 49b5b5d6
Showing 9 changed files with 689 additions and 149 deletions:

    cuda/fastermoe/smart_schedule.cpp    +133    -0
    cuda/fastermoe/smart_schedule.h      +369    -0
    cuda/fmoe_cuda.cpp                   +12     -0
    cuda/fused_exchange.cu               +0      -146
    fmoe/fastermoe/__init__.py           +0      -0
    fmoe/fastermoe/schedule.py           +103    -0
    fmoe/layers.py                       +9      -2
    setup.py                             +2      -1
    tests/test_faster_schedule.py        +61     -0
cuda/fastermoe/smart_schedule.cpp  (new file, 0 → 100644)
#ifdef FMOE_USE_NCCL
#include <cstdlib>
#include <vector>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "smart_schedule.h"
long pipeline_gran = -1;

torch::Tensor _smart_sch_forward(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long global_batch_size,
        long n_workers,
        py::function forward_fn) {
    if (pipeline_gran == -1) {
        char* p = getenv("FMOE_FASTER_GROUP_SIZE");
        if (p) {
            pipeline_gran = atoi(p);
        } else {
            pipeline_gran = 4;
        }
    }

    auto smgr = getCudaStreamManager(input_buf.device().index());
    int rank;
    NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));

    const auto num_expert = local_expert_count.size(0) / n_workers;
    const auto d_model = input_buf.size(1);

    auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_cuda_fused_forward", ([&] {
        fmoe_cuda_fused_forward_impl(
            forward_fn,
            input_buf.device(),
            input_buf.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
            global_output_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers, pipeline_gran,
            smgr);
    }));
    return output_buf;
}
/*
std::vector<torch::Tensor> _fused_backward(
torch::Tensor input_buf,
std::vector<std::vector<std::vector<torch::Tensor>>> params,
torch::Tensor middle_buf,
torch::Tensor output_buf,
torch::Tensor grad_out,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
torch::Tensor inp,
torch::Tensor stored_models,
long global_batch_size,
long buf_batch_size,
long n_workers, bool has_bias) {
const auto num_expert = local_expert_count.size(0) / n_workers;
auto smgr = getCudaStreamManager(input_buf.device().index());
int rank;
ncclCommUserRank(smgr->ncclcomm, &rank);
const auto d_hidden = params[rank][0][0].size(1);
const auto d_model = params[rank][0][0].size(2);
auto global_grad_out = input_buf.new_zeros({global_batch_size, d_model});
auto grad_middle = input_buf.new_zeros({global_batch_size, d_hidden});
auto global_grad_in = input_buf.new_zeros({global_batch_size, d_model});
auto grad_in = input_buf.new_zeros({buf_batch_size, d_model});
for (auto node : params)
for (auto expert : node)
for (int i = 0; i < expert.size(); i++) {
// create the respective gradient of each tensor
CHECK_INPUT(expert[i]);
if (expert[i].grad().defined()) {
CHECK_INPUT(expert[i].grad());
continue;
}
expert[i].mutable_grad() = input_buf.new_zeros(expert[i].sizes());
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
"fmoe_cuda_fused_backward", ([&] {
fmoe_cuda_fused_backward_impl(
input_buf.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
params,
middle_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
grad_out.data_ptr<scalar_t>(),
global_grad_out.data_ptr<scalar_t>(),
global_grad_in.data_ptr<scalar_t>(),
grad_middle.data_ptr<scalar_t>(),
grad_in.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
stored_models.data_ptr<bool>(),
d_model, d_hidden, num_expert, rank, n_workers, has_bias,
pipeline_gran, smgr);
}));
return {grad_in,};
}
*/
#endif
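The `forward_fn` argument above is a Python callback: `_compute_forward` in `smart_schedule.h` wraps slices of the global input/output buffers as tensors and calls it once per pipeline step, and the callback is expected to write its result into the output view in place (this is what `_expert_forward` in `fmoe/fastermoe/schedule.py` does). A minimal sketch of that contract, with a hypothetical `expert` module standing in for the real expert function:

    import torch

    def make_forward_fn(expert):
        # expert: any callable mapping (micro_batch_size, d_model) -> (micro_batch_size, d_model)
        def forward_fn(inp, out, step):
            # inp/out are views over preallocated CUDA buffers created with
            # torch::from_blob on the C++ side; the result must be copied into out.
            out.copy_(expert(inp))
        return forward_fn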
cuda/fastermoe/smart_schedule.h  (new file, 0 → 100644)
#ifndef SMART_SCHEDULE_H
#define SMART_SCHEDULE_H
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include "../stream_manager.h"
template<typename scalar_t>
void _exchange_with(
        const scalar_t* sendbuf, size_t sendcount, int t_send,
        scalar_t* recvbuf, size_t recvcount, int t_recv,
        long d_model, cudaStream_t stream, ncclComm_t comm) {
    if (sendcount) {
        ncclSend(sendbuf, sendcount * d_model * sizeof(scalar_t),
                ncclChar, t_send, comm, stream);
    }
    if (recvcount) {
        ncclRecv(recvbuf, recvcount * d_model * sizeof(scalar_t),
                ncclChar, t_recv, comm, stream);
    }
}
#define GEN_BASE(_step) \
long to_base = (group_rank + _step) % n_groups * pipeline_gran; \
long from_base = (group_rank + n_groups - _step) % n_groups * pipeline_gran;
#define GEN_IDX \
int idx_send = ei + rank_send * num_expert; \
int idx_recv = ei + rank_recv * num_expert; \
int gidx_send = ei * world_size + rank_send; \
int gidx_recv = ei * world_size + rank_recv; \
int idx_self = ei + rank * num_expert;
void _compute_ptrs(long num_expert, long rank, long world_size,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        int* local_ptr,
        int* global_ptr,
        int* local_global_ptr) {
    local_ptr[0] = global_ptr[0] = local_global_ptr[0] = 0;
    for (int i = 0; i < num_expert * world_size; ++i) {
        local_ptr[i + 1] = local_ptr[i] + local_expert_count[i];

        local_global_ptr[i + 1] = local_global_ptr[i];
        // if model fetched, add local tokens
        if (stored_models[i]) {
            local_global_ptr[i + 1] += local_expert_count[i];
        }

        auto expert_idx = i % num_expert;
        auto worker_idx = i / num_expert;
        auto gp_idx = expert_idx * world_size + worker_idx;
        // if local model wasn't fetched, receive global tokens
        if (stored_models[rank * num_expert + expert_idx]) {
            global_ptr[gp_idx + 1] = 0;
        } else {
            global_ptr[gp_idx + 1] = global_expert_count[i];
        }
    }
    global_ptr[0] = 0;
    for (int i = 0; i < num_expert * world_size; ++i) {
        global_ptr[i + 1] += global_ptr[i];
    }
}
template<typename scalar_t>
void _compute_forward(
        py::function fn, c10::Device device,
        scalar_t* inp_buf, scalar_t* out_buf,
        int ei, long step,
        long offset, long micro_batch_size, long d_model) {
    auto options = torch::TensorOptions()
        .dtype(c10::CppTypeToScalarType<scalar_t>::value)
        .device(device)
        .requires_grad(true);
    auto inp = torch::from_blob(inp_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
    auto oup = torch::from_blob(out_buf + offset * d_model,
            {micro_batch_size, d_model}, options);
    fn(inp, oup, step);
}
template<typename scalar_t>
void _compute_backward(
        py::function fn,
        scalar_t* inp_buf, scalar_t* out_buf,
        long* local_expert_count, long* global_expert_count,
        int ei,
        long offset, long micro_batch_size, long d_model) {
}
template<typename scalar_t>
void fmoe_cuda_fused_forward_impl(
        py::function forward_fn,
        c10::Device device,
        const scalar_t* input_buf,
        scalar_t* global_input_buf,
        scalar_t* global_output_buf,
        scalar_t* output_buf,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        long d_model,
        long num_expert, long rank, long world_size, long pipeline_gran,
        CudaStreamManager* smgr) {
    int* local_ptr = new int[num_expert * world_size + 1];
    int* global_ptr = new int[num_expert * world_size + 1];
    int* local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker

    _compute_ptrs(num_expert, rank, world_size,
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);

    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    cudaEvent_t* input_ready = new cudaEvent_t[n_groups];
    cudaEvent_t* output_ready = new cudaEvent_t[n_groups];
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(input_ready + i);
        cudaEventCreate(output_ready + i);
    }

    // Phase 1: per pipeline step, exchange input tokens with the partner group
    // of workers that hosts the corresponding experts.
    for (long step = 0; step < n_groups; ++step) {
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
                _exchange_with(input_buf + local_ptr[idx_send] * d_model,
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_input_buf + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }

    // Phase 2: as soon as the inputs of a step have arrived, run the expert
    // (the Python callback) on that micro-batch.
    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
            long micro_batch_size = global_ptr[ei * world_size +
                    (from_base + pipeline_gran)] - offset;
            _compute_forward(forward_fn, device,
                    global_input_buf, global_output_buf,
                    ei, step, offset, micro_batch_size, d_model);
        }
        auto stream = c10::cuda::getCurrentCUDAStream().stream();
        cudaEventRecord(output_ready[step], stream);
    }

    // Phase 3: send the expert outputs of each step back to the workers that
    // own the corresponding tokens.
    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
                _exchange_with(global_output_buf + global_ptr[gidx_send] * d_model,
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        output_buf + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }
/* TODO: Shadowing support
int offset = global_ptr[world_size * num_expert];
for (int j = 0; j < world_size; j++) {
for (int i = 0; i < num_expert; i++) {
int idx = j * num_expert + i;
if (!stored_models[idx])
continue;
weight1 = params[j][0][0].data_ptr<scalar_t>();
weight2 = params[j][0][last].data_ptr<scalar_t>();
auto stream = 2 + (idx % (SMGR_N_STREAMS- 2));
_compute_mlp_forward(
input_buf + local_ptr[idx] * d_model, weight1, weight2,
middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf + local_ptr[idx] * d_model,
i,
0, local_expert_count[idx],
d_model, d_hidden,
smgr->stream(stream), smgr->handle(stream));
}
}*/
    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;

    checkCudaErrors(cudaGetLastError());
    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
    delete [] input_ready;
    delete [] output_ready;
}
template<typename scalar_t>
void fmoe_cuda_fused_backward_impl(
        py::function backward_fn,
        const scalar_t* input_buf,
        const scalar_t* output_buf,
        const scalar_t* grad_out,
        scalar_t* global_grad_out,
        scalar_t* global_grad_in,
        scalar_t* grad_in,
        const long* local_expert_count,
        const long* global_expert_count,
        const bool* stored_models,
        long d_model, long d_hidden,
        long num_expert, long rank, long world_size, long pipeline_gran,
        CudaStreamManager* smgr) {
    int* local_ptr = new int[num_expert * world_size + 1];
    int* global_ptr = new int[num_expert * world_size + 1];
    int* local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker

    _compute_ptrs(num_expert, rank, world_size,
            local_expert_count, global_expert_count, stored_models,
            local_ptr, global_ptr, local_global_ptr);

    if (pipeline_gran > world_size) {
        pipeline_gran = world_size;
    }
    long n_groups = world_size / pipeline_gran;
    long group_rank = rank / pipeline_gran;

    cudaEvent_t* input_ready = new cudaEvent_t[n_groups];
    cudaEvent_t* output_ready = new cudaEvent_t[n_groups];
    for (long i = 0; i < n_groups; ++i) {
        cudaEventCreate(input_ready + i);
        cudaEventCreate(output_ready + i);
    }

    for (long step = 0; step < n_groups; ++step) {
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + to_base;
                int rank_recv = j + from_base;
                GEN_IDX;
                _exchange_with(grad_out + local_ptr[idx_send] * d_model,
                        local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                        global_grad_out + global_ptr[gidx_recv] * d_model,
                        global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
        cudaEventRecord(input_ready[step], smgr->stream(0));
    }

    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(1), input_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            long offset = global_ptr[ei * world_size + from_base];
            long micro_batch_size = global_ptr[ei * world_size +
                    (from_base + pipeline_gran)] - offset;
            _compute_backward(backward_fn,
                    input_buf, output_buf,
                    global_grad_out, global_grad_in,
                    ei, offset, micro_batch_size);
        }
        // TODO: get pytorch's compute stream
    }

    for (long step = 0; step < n_groups; ++step) {
        cudaStreamWaitEvent(smgr->stream(0), output_ready[step], 0);
        for (int ei = 0; ei < num_expert; ++ei) {
            GEN_BASE(step);
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < pipeline_gran; ++j) {
                int rank_send = j + from_base;
                int rank_recv = j + to_base;
                GEN_IDX;
                _exchange_with(global_grad_in + global_ptr[gidx_send] * d_model,
                        global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                        grad_in + local_ptr[idx_recv] * d_model,
                        local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
                        d_model, smgr->stream(0), smgr->ncclcomm);
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }
    checkCudaErrors(cudaGetLastError());
/* TODO: Shadowing support
int offset = global_ptr[world_size * num_expert];
for (int j = 0; j < world_size; j++) {
for (int i = 0; i < num_expert; i++) {
int idx = j * num_expert + i;
if (!stored_models[idx])
continue;
weight1 = params[j][0][0].data_ptr<scalar_t>();
weight2 = params[j][0][last].data_ptr<scalar_t>();
grad_weight1 = params[j][0][0].mutable_grad().data_ptr<scalar_t>();
grad_weight2 = params[j][0][last].mutable_grad().data_ptr<scalar_t>();
auto stream = 2 + (idx % (SMGR_N_STREAMS- 2));
_compute_mlp_backward(
original_input_buf + local_ptr[idx] * d_model, weight1, weight2,
middle_buf + (offset + local_global_ptr[idx]) * d_hidden, output_buf, grad_out + local_ptr[idx] * d_model,
grad_middle + (offset + local_global_ptr[idx]) * d_hidden, grad_weight1, grad_weight2, grad_in + local_ptr[idx] * d_model,
i,
0, local_expert_count[idx],
d_model, d_hidden, 0, // we never consider it to be the first since it's already initialized to zero and we are lazy
smgr->stream(stream), smgr->handle(stream));
}
}
*/
    delete [] local_ptr;
    delete [] global_ptr;
    delete [] local_global_ptr;

    checkCudaErrors(cudaGetLastError());
    for (long i = 0; i < n_groups; ++i) {
        cudaEventDestroy(input_ready[i]);
        cudaEventDestroy(output_ready[i]);
    }
    delete [] input_ready;
    delete [] output_ready;
}
#endif // SMART_SCHEDULE_H
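The three loops above run over `n_groups = world_size / pipeline_gran` pipeline steps, and the `GEN_BASE` macro selects, for each step, which group of ranks to send to and which to receive from, so the communication of one step can overlap with the computation of another. A small Python illustration of that schedule (not part of the commit, purely to show what `to_base`/`from_base` evaluate to):

    def partner_ranks(rank, world_size, pipeline_gran, step):
        # Mirrors GEN_BASE: groups of pipeline_gran ranks exchange with a
        # partner group that rotates by `step`.
        n_groups = world_size // pipeline_gran
        group_rank = rank // pipeline_gran
        to_base = (group_rank + step) % n_groups * pipeline_gran
        from_base = (group_rank + n_groups - step) % n_groups * pipeline_gran
        send_to = [to_base + j for j in range(pipeline_gran)]
        recv_from = [from_base + j for j in range(pipeline_gran)]
        return send_to, recv_from

    # With 8 workers and pipeline_gran = 4 there are 2 steps: at step 0 a rank
    # exchanges within its own group, at step 1 with the other group.
    print(partner_ranks(0, 8, 4, 1))   # ([4, 5, 6, 7], [4, 5, 6, 7])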
cuda/fmoe_cuda.cpp
#include <iostream>
#include <vector>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/extension.h>
// global_exchange
...
...
@@ -56,6 +57,15 @@ std::vector<torch::Tensor> _swipe_once(
    torch::Tensor gate_idx, torch::Tensor capacity_tensor,
    long n_expert, long n_worker, long bias);

+// smart scheduling
+torch::Tensor _smart_sch_forward(
+        torch::Tensor input_buf,
+        torch::Tensor local_expert_count,
+        torch::Tensor global_expert_count,
+        torch::Tensor stored_models,
+        long global_batch_size,
+        long n_workers,
+        py::function forward_fn);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
    m.def("expert_exchange", &_expert_exchange, "FastMoE expert exchange (CUDA)");
...
...
@@ -63,6 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
    m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
    m.def("swipe_once", &_swipe_once, "SWIPE balance strategy(CUDA)");

+    m.def("smart_sch_forward", &_smart_sch_forward,
+            "E2E MoE layer forward with smart scheduling");
#endif
    m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
...
...
cuda/fused_exchange.cu  (deleted, 100644 → 0)
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#include "cublas_wrapper.h"
#ifdef FMOE_USE_NCCL
#include <nccl.h>
template<typename scalar_t>
void moe_cuda_global_fused_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        scalar_t* global_input_buf,
        scalar_t* global_output_buf,
        scalar_t* output_buf,
        const long* local_expert_count,
        const long* global_expert_count,
        long in_feat, long out_feat,
        long num_expert, long world_size,
        CudaStreamManager* smgr) {
    int ptr = 0;
    int send_ptr = 0;
    int recv_ptr = 0;
    int* expert_ptr = new int[num_expert * world_size];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }
    scalar_t alpha = 1, beta = 0;

    for (int i = 0; i < num_expert; ++i) {
        int expert_count = 0;
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        input_buf + expert_ptr[idx] * in_feat,
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar, j, smgr->ncclcomm, smgr->stream(i)));
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        global_input_buf + recv_ptr * in_feat,
                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar, j, smgr->ncclcomm, smgr->stream(i)));
                recv_ptr += global_expert_count[idx];
                expert_count += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());

        checkCudaErrors(cublasXgemm(smgr->handle(i),
                CUBLAS_OP_T, CUBLAS_OP_N,
                out_feat, expert_count, in_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                global_input_buf + ptr * in_feat, in_feat,
                &beta,
                global_output_buf + out_feat * ptr, out_feat));
        ptr += expert_count;

        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        global_output_buf + send_ptr * out_feat,
                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar, j, smgr->ncclcomm, smgr->stream(i)));
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        output_buf + expert_ptr[idx] * out_feat,
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar, j, smgr->ncclcomm, smgr->stream(i)));
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(num_expert);
}

std::vector<torch::Tensor> moe_cuda_global_fused_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long global_batch_size, long local_batch_size, long n_workers) {
    const auto num_expert = local_expert_count.size(0) / n_workers;
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

    auto smgr = getCudaStreamManager(input_buf.device().index());

    auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
    auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
    auto output_buf = input_buf.new_empty({local_batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "moe_cuda_global_fused_forward", ([&] {
        moe_cuda_global_fused_forward_impl(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
            global_output_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            in_feat, out_feat, num_expert, n_workers,
            smgr);
    }));
    return {output_buf, global_input_buf};
}
#endif
fmoe/fastermoe/__init__.py  (new file, 0 → 100644, empty)
fmoe/fastermoe/schedule.py  (new file, 0 → 100644)
r"""
The smart schedule proposed in FasterMoE.
"""
import torch
from torch.autograd.function import Function

from fmoe.functions import prepare_forward, ensure_comm
from fmoe.functions import _local_scatter, _local_gather
import fmoe_cuda as fmoe_native


class MoEForward(Function):
    @staticmethod
    def forward(
            ctx,
            expert_fn,
            inp,
            # models,
            pos_s, pos_g,
            local_expert_count, global_expert_count,
            stored_models,
            fwd_batch_size, out_batch_size,
            world_size):
        local_input_buf = _local_scatter(inp, pos_s)
        # TODO: leave this for future work of expert shadowing
        # model_params = [[tuple(m.parameters()) for m in node] for node in models]

        ctx.gibs = [None] * world_size
        ctx.gobs = [None] * world_size

        def _expert_forward(x, y, idx):
            x = x.data
            x.requires_grad = True
            y0 = expert_fn(x, [x.shape[0]])
            ctx.gibs[idx] = x
            ctx.gobs[idx] = y0
            y.copy_(y0)

        local_output_buf = fmoe_native.smart_sch_forward(
                local_input_buf,
                local_expert_count, global_expert_count, stored_models,
                fwd_batch_size, world_size, _expert_forward)

        out = _local_gather(local_output_buf, pos_g, out_batch_size,
                maybe_overlap=False)

        variables = (pos_s, pos_g, local_expert_count, global_expert_count,
                stored_models)
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
        ctx.save_for_backward(*variables)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        (pos_s, pos_g, local_expert_count, global_expert_count,
                gib, gmb, gob, stored_models) = ctx.saved_tensors
        (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args

        def _expert_backward(grad, idx):
            y = ctx.gobs[idx]
            torch.autograd.backward([y], [grad])
            x = ctx.gibs[idx]
            return x.grad

        grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
        grad_in_buf = fmoe_native.smart_sch_backward(
                gib, gmb, gob,
                grad_out_buf,
                local_expert_count, global_expert_count, stored_models,
                fwd_batch_size, pos_s.shape[0], world_size,
                _expert_backward)
        grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)

        return (None, grad_in, None, None, None, None, None, None, None, None)


def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
    # TODO: Using multiple tensors as input is to be supported.
    assert(isinstance(inp, torch.Tensor))
    # TODO: Support many experts on each process
    assert(n_expert == 1)
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, n_expert, world_size)

    # TODO: Expert shadowing is to be supported. Currently using all 0s
    stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)

    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]
    out_batch_size = inp.shape[0] * topk

    return MoEForward.apply(
            expert_fn, inp,
            torch.div(pos, topk, rounding_mode='floor'), pos,
            local_expert_count, global_expert_count, stored_models,
            fwd_batch_size, out_batch_size, world_size)
fmoe/layers.py
...
...
@@ -2,6 +2,7 @@ r"""
FMoE core layer
"""
import tree
+import os
import torch
import torch.nn as nn
...
...
@@ -46,7 +47,7 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    def scatter_func(tensor):
        return MOEScatter.apply(
            tensor,
-            pos // topk,
+            torch.div(pos, topk, rounding_mode='floor'),
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
...
...
@@ -75,6 +76,10 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    return outp


+if os.environ.get('FMOE_FASTER_SCHEDULE_ENABLE', '0') in ['1', 'ON']:
+    from .fastermoe.schedule import _fmoe_general_global_forward


class FMoE(nn.Module):
    r"""
    A general moe implementation that supports an arbitrary module as the
...
...
@@ -149,10 +154,12 @@ class FMoE(nn.Module):
"""
if
self
.
experts_fused
:
return
self
.
experts
(
inp
,
fwd_expert_count
)
if
isinstance
(
fwd_expert_count
,
torch
.
Tensor
):
fwd_expert_count
=
fwd_expert_count
.
cpu
().
numpy
()
outputs
=
[]
base_idx
=
0
for
i
in
range
(
self
.
num_expert
):
batch_size
=
fwd_expert_count
[
i
]
.
item
()
batch_size
=
fwd_expert_count
[
i
]
inp_slice
=
inp
[
base_idx
:
base_idx
+
batch_size
]
outputs
.
append
(
self
.
experts
[
i
](
inp_slice
))
base_idx
+=
batch_size
...
...
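Together with the `FMOE_FASTER_GROUP_SIZE` lookup in `smart_schedule.cpp`, this hook makes the new schedule opt-in via environment variables. A minimal usage sketch (the top-level `import fmoe` is an assumption for illustration, not part of this diff):

    import os

    # FMOE_FASTER_SCHEDULE_ENABLE is checked when fmoe.layers is imported, so it
    # must be set before the import; FMOE_FASTER_GROUP_SIZE is read lazily by
    # _smart_sch_forward on the first forward pass and defaults to 4.
    os.environ['FMOE_FASTER_SCHEDULE_ENABLE'] = '1'   # '1' or 'ON'
    os.environ['FMOE_FASTER_GROUP_SIZE'] = '4'        # pipeline granularity

    import fmoe  # fmoe.layers now rebinds _fmoe_general_global_forward
                 # to fmoe.fastermoe.schedule._fmoe_general_global_forward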
setup.py
...
...
@@ -43,7 +43,7 @@ if __name__ == '__main__':
        author_email='hja20@mails.tsinghua.edu.cn',
        license='Apache-2',
        url='https://github.com/laekov/fastmoe',
-        packages=['fmoe', 'fmoe.megatron', 'fmoe.gates'],
+        packages=['fmoe', 'fmoe.megatron', 'fmoe.gates', 'fmoe.fastermoe'],
        ext_modules=[
            CUDAExtension(
                name='fmoe_cuda',
...
...
@@ -54,6 +54,7 @@ if __name__ == '__main__':
                    'cuda/global_exchange.cpp',
                    'cuda/parallel_linear.cu',
                    'cuda/fmoe_cuda.cpp',
+                    'cuda/fastermoe/smart_schedule.cpp',
                    ],
                define_macros=define_macros,
                extra_compile_args={
...
...
tests/test_faster_schedule.py  (new file, 0 → 100644)
import pytest
import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.functions import ensure_comm

from test_ddp import _ensure_initialized, _run_distributed
from test_numerical import _assert_numerical

from fmoe.fastermoe.schedule import _fmoe_general_global_forward as smart_fwd
from fmoe.layers import _fmoe_general_global_forward as naive_fwd


@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1])
def test_faster_schedule(n_process, d_model, batch_size, n_expert):
    _run_distributed('_test_faster_schedule',
            n_process,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert
            },
            script=__file__,
            env=dict())


def _test_faster_schedule(d_model, batch_size, n_expert):
    _ensure_initialized()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    x = torch.rand(batch_size, d_model).cuda()
    x.requires_grad = True
    topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda()
    m = torch.nn.Linear(d_model, d_model).cuda()

    def expert_fn(x, fec):
        y = m(x)
        return y

    ensure_comm(x, None)

    y = smart_fwd(x, topk_idx, expert_fn, n_expert, world_size)
    z = naive_fwd(x, topk_idx, expert_fn, n_expert, world_size)
    _assert_numerical(['out'], [y], [z], rank)


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        # test_faster_schedule(8, 16, 16, 1)
        _test_faster_schedule(4, 2, 1)