FastMoE (OpenDAS group), commit 2d250fbf
Authored Jan 28, 2021 by Rick Ho
Parent: 293eef6d

Commit message: make test run on nccl version, but fails in correctness
Showing 7 changed files, with 88 additions and 46 deletions:

  cuda/moe_comm_kernel.cu    +45  -30
  cuda/moe_cuda_kernel.h      +3   -0
  cuda/moe_fused_kernel.cu    +4   -5
  fmoe/moe_function.py       +17   -3
  setup.py                    +0   -1
  tests/dev_test.sh           +2   -0
  tests/moe_test.py          +17   -7
cuda/moe_comm_kernel.cu

```diff
@@ -13,50 +13,60 @@
 #include "cuda_stream_manager.h"
 
 #ifdef MOE_USE_NCCL
-#include <mpi.h>
 #include <nccl.h>
 
 void moe_cuda_expert_exchange_impl(
-        const int* local_expert_count,
-        int* global_expert_count,
-        int* fwd_expert_count,
-        int num_expert, int world_size) {
-    MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
-            global_expert_count, num_expert, MPI_INT,
-            MPI_COMM_WORLD);
-    for (int i = 0; i < num_expert; ++i) {
-        for (int j = 0; j < world_size; ++j) {
-            fwd_expert_count[i] += global_expert_count[i + j * num_expert];
-        }
+        const long* local_expert_count,
+        long* global_expert_count,
+        int num_expert, int world_size,
+        CudaStreamManager* smgr) {
+    NCCL_SAFE_CALL(ncclGroupStart());
+    for (int i = 0; i < world_size; ++i) {
+        NCCL_SAFE_CALL(ncclSend(
+                local_expert_count + num_expert * i,
+                num_expert,
+                ncclInt64,
+                i,
+                smgr->ncclcomm,
+                smgr->stream(0)));
+        NCCL_SAFE_CALL(ncclRecv(
+                global_expert_count + num_expert * i,
+                num_expert,
+                ncclInt64,
+                i,
+                smgr->ncclcomm,
+                smgr->stream(0)));
     }
+    NCCL_SAFE_CALL(ncclGroupEnd());
+    smgr->sync(1);
 }
 
 std::vector<torch::Tensor> moe_cuda_expert_exchange(
         torch::Tensor local_expert_count,
         long num_expert, long n_workers) {
     auto global_expert_count = torch::empty_like(local_expert_count);
-    auto fwe_options = torch::TensorOptions()
-        .dtype(local_expert_count.dtype());
-    auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);
+    auto smgr = getCudaStreamManager(
+            local_expert_count.device().index());
     moe_cuda_expert_exchange_impl(
-            local_expert_count.data_ptr<int>(),
-            global_expert_count.data_ptr<int>(),
-            fwd_expert_count.data_ptr<int>(),
-            num_expert, n_workers);
-    return {global_expert_count, fwd_expert_count};
+            local_expert_count.data_ptr<long>(),
+            global_expert_count.data_ptr<long>(),
+            num_expert, n_workers,
+            smgr);
+    return {global_expert_count};
 }
 
 template<typename scalar_t>
 void moe_cuda_global_scatter_impl(
     const scalar_t* local_input_buf,
-    const int* local_expert_count,
-    const int* global_expert_count,
+    const long* local_expert_count,
+    const long* global_expert_count,
     scalar_t* input_buf,
     size_t in_feat, size_t num_expert, size_t world_size,
     CudaStreamManager* smgr) {
     // assert world_size > 1
     int recv_ptr = 0;
     /* TODO: may save for backward */
-    int* expert_ptr = new int[num_expert * world_size];
+    long* expert_ptr = new long[num_expert * world_size];
     expert_ptr[0] = 0;
     for (int i = 1; i < num_expert * world_size; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
@@ -106,8 +116,8 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
             "moe_cuda_global_scatter", ([&] {
         moe_cuda_global_scatter_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
-            local_expert_count.data_ptr<int>(),
-            global_expert_count.data_ptr<int>(),
+            local_expert_count.data_ptr<long>(),
+            global_expert_count.data_ptr<long>(),
             global_input_buf.data_ptr<scalar_t>(),
             in_feat, num_expert, n_workers,
             smgr
@@ -119,14 +129,14 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
 template<typename scalar_t>
 void moe_cuda_global_gather_impl(
     const scalar_t* output_buf,
-    const int* local_expert_count,
-    const int* global_expert_count,
+    const long* local_expert_count,
+    const long* global_expert_count,
     scalar_t* local_output_buf,
     size_t out_feat, size_t num_expert, size_t world_size,
     CudaStreamManager* smgr) {
-    int send_ptr = 0;
+    long send_ptr = 0;
     /* TODO: may save for backward */
-    int* expert_ptr = new int[num_expert * world_size];
+    long* expert_ptr = new long[num_expert * world_size];
     expert_ptr[0] = 0;
     for (int i = 1; i < num_expert * world_size; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
@@ -176,8 +186,8 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
             "moe_cuda_global_gather", ([&] {
         moe_cuda_global_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
-            local_expert_count.data_ptr<int>(),
-            global_expert_count.data_ptr<int>(),
+            local_expert_count.data_ptr<long>(),
+            global_expert_count.data_ptr<long>(),
             local_output_buf.data_ptr<scalar_t>(),
             out_feat, num_expert, n_workers,
             smgr
@@ -186,4 +196,9 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
     return {local_output_buf,};
 }
 
+void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
+    auto smgr = getCudaStreamManager(0);
+    smgr->ensure((void*)&p, t.device());
+}
+
 #endif
```
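The new exchange sends block i of the count buffer to rank i and receives the matching block from rank i inside one NCCL group. As a rough cross-check of those semantics, the same exchange can be written in a few lines of PyTorch. This sketch is not part of the commit; `expert_exchange_reference` is a hypothetical helper, and it assumes torch.distributed has already been initialized with the NCCL backend on a PyTorch version that provides `all_to_all_single`.

```python
import torch
import torch.distributed as dist

def expert_exchange_reference(local_expert_count: torch.Tensor) -> torch.Tensor:
    # local_expert_count: int64 CUDA tensor of length world_size * num_expert,
    # laid out as world_size contiguous blocks of num_expert counts, the same
    # layout the ncclSend/ncclRecv loop above walks via num_expert * i.
    global_expert_count = torch.empty_like(local_expert_count)
    # all_to_all_single splits the input into world_size equal chunks, sends
    # chunk i to rank i, and writes the chunk received from rank i into slot i
    # of the output, matching the grouped send/recv pairs.
    dist.all_to_all_single(global_expert_count, local_expert_count)
    return global_expert_count
```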
cuda/moe_cuda_kernel.h

```diff
@@ -41,6 +41,9 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
         torch::Tensor global_expert_count,
         long batch_size, long n_workers);
 
+#include <c10d/ProcessGroupNCCL.hpp>
+void moe_ensure_nccl(c10d::ProcessGroupNCCL&, torch::Tensor t);
+
 std::vector<torch::Tensor> moe_cuda_expert_exchange(
         torch::Tensor local_expert_count,
         long num_expert, long n_workers);
```
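How the new declaration is reached from Python in this commit (see fmoe/moe_function.py below): the default c10d process group is handed to the extension so its CudaStreamManager can reuse the NCCL communicator. A sketch only, with a made-up tensor; it assumes the extension was built with USE_NCCL=1 and that torch.distributed has already been initialized with the NCCL backend.

```python
import torch
import fmoe_cuda

# Any CUDA tensor works; it only tells the stream manager which device to use.
t = torch.empty(1, device='cuda')
fmoe_cuda.ensure_nccl(torch.distributed.distributed_c10d._default_pg, t)
```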
cuda/moe_fused_kernel.cu

```diff
@@ -14,7 +14,6 @@
 #include "cublas_wrapper.h"
 
 #ifdef MOE_USE_NCCL
-#include <mpi.h>
 #include <nccl.h>
 
 template<typename scalar_t>
@@ -24,8 +23,8 @@ void moe_cuda_global_fused_forward_impl(
         scalar_t* global_input_buf,
         scalar_t* global_output_buf,
         scalar_t* output_buf,
-        const int* local_expert_count,
-        const int* global_expert_count,
+        const long* local_expert_count,
+        const long* global_expert_count,
         long in_feat, long out_feat,
         long num_expert, long world_size,
         CudaStreamManager* smgr) {
@@ -136,8 +135,8 @@ std::vector<torch::Tensor> moe_cuda_global_fused_forward(
             global_input_buf.data_ptr<scalar_t>(),
             global_output_buf.data_ptr<scalar_t>(),
             output_buf.data_ptr<scalar_t>(),
-            local_expert_count.data_ptr<int>(),
-            global_expert_count.data_ptr<int>(),
+            local_expert_count.data_ptr<long>(),
+            global_expert_count.data_ptr<long>(),
             in_feat, out_feat, num_expert, n_workers,
             smgr);
     }));
```
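Side note on the type change: with the kernels now reading the count buffers through data_ptr<long>() and exchanging them as ncclInt64, the Python side has to hand over int64 tensors, which is why fmoe/moe_function.py below builds local_expert_count with dtype=torch.long. A tiny illustration, not from the commit:

```python
import torch

# torch.long is an alias for 64-bit integers, matching ncclInt64.
local_expert_count = torch.zeros(4, dtype=torch.long)
assert local_expert_count.dtype == torch.int64
```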
fmoe/moe_function.py

```diff
@@ -38,16 +38,30 @@ class MOELocal(Function):
 class MOEGlobal(Function):
     @staticmethod
     def forward(ctx, inp, gate, weight, world_size):
+        fmoe_cuda.ensure_nccl(
+                torch.distributed.distributed_c10d._default_pg, inp)
+
         num_expert = weight.shape[0]
-        local_expert_count, pos = fmoe_cuda.expert_count(gate,
-                world_size * num_expert)
-        global_expert_count, fwd_expert_count = fmoe_cuda.expert_exchange(
+        # local_expert_count, pos = fmoe_cuda.expert_count(gate,
+        #         world_size * num_expert)
+        _, pos = torch.sort(gate)
+        gate_idx, gate_count = torch.unique(gate, return_counts=True)
+        local_expert_count = torch.zeros(weight.shape[0] * world_size,
+                device=weight.device, dtype=torch.long)
+        local_expert_count.index_put_((gate_idx.long(), ), gate_count)
+
+        global_expert_count, = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size)
+        print('Local {} Global {}'.format(local_expert_count,
+            global_expert_count))
+        fwd_expert_count = global_expert_count.view(num_expert,
+                world_size).sum(dim=1).cpu()
         fwd_batch_size = int(fwd_expert_count.sum().item())
 
         local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
+        local_expert_count = local_expert_count.cpu()
+        global_expert_count = global_expert_count.cpu()
         local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
                 local_input_buf, weight, local_expert_count, global_expert_count,
```
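For reference, the count bookkeeping that forward() now does in plain PyTorch can be run on a toy gate tensor; the sizes and gate values below are made up for illustration, and a copy of the local counts stands in for the exchanged result. One detail worth keeping in mind when reading it: the CUDA loop removed in cuda/moe_comm_kernel.cu accumulated global_expert_count[i + j * num_expert] over ranks j, which treats the exchanged buffer as (world_size, num_expert); the reshape below follows this commit's Python code as written.

```python
import torch

num_expert, world_size = 2, 2                     # hypothetical sizes
gate = torch.tensor([0, 3, 3, 1, 0, 2])           # expert chosen per token

_, pos = torch.sort(gate)                         # token order grouped by expert
gate_idx, gate_count = torch.unique(gate, return_counts=True)
local_expert_count = torch.zeros(num_expert * world_size, dtype=torch.long)
local_expert_count.index_put_((gate_idx.long(),), gate_count)
print(local_expert_count)                         # tensor([2, 1, 1, 2])

# Stand-in for the exchanged counts; in the real forward() this comes from
# fmoe_cuda.expert_exchange.
global_expert_count = local_expert_count.clone()

# Per-expert totals and the forward batch size, as computed in the commit.
fwd_expert_count = global_expert_count.view(num_expert, world_size).sum(dim=1)
fwd_batch_size = int(fwd_expert_count.sum().item())
print(fwd_expert_count, fwd_batch_size)           # tensor([3, 3]) 6
```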
setup.py

```diff
@@ -8,7 +8,6 @@ cxx_flags = [
 ]
 
 if os.environ.get('USE_NCCL', '0') == '1':
     cxx_flags.append('-DMOE_USE_NCCL')
-    os.environ['CXX'] = 'mpicxx'
 
 if __name__ == '__main__':
     setuptools.setup(
```
tests/dev_test.sh

```diff
@@ -2,6 +2,8 @@
 if [ ! -z $OMPI_COMM_WORLD_LOCAL_RANK ]
 then
     export CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_LOCAL_RANK
+    export MASTER_ADDR=localhost
+    export MASTER_PORT=36666
 fi
 if [ -z $OMPI_COMM_WORLD_RANK ]
```
tests/moe_test.py

```diff
@@ -4,6 +4,7 @@ import torch
 from torch import nn
 import time
 import sys
+import os
 
 dev_name_default = 'cuda:0'
@@ -105,10 +106,10 @@ def test():
     if world_size == 1:
         moe_raw.weight.data = moe.weight.data.clone()
     else:
-        weight_array = [torch.empty_like(moe.weight.data).cpu()
+        weight_array = [torch.empty_like(moe.weight.data)
                 for _ in range(world_size)]
-        torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
-        moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()
+        torch.distributed.all_gather(weight_array, moe.weight.data)
+        moe_raw.weight.data = torch.cat(weight_array, dim=0)
 
     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0,
@@ -124,13 +125,20 @@ def test():
     if world_size > 1:
         rank = torch.distributed.get_rank()
         ou, wg, lwg, lbg = raw_out
-        wg = wg.cpu()
         torch.distributed.all_reduce(wg)
         wg = wg[rank * num_expert:(rank + 1) * num_expert]
-        raw_out = ou, wg.cuda(), lwg, lbg
+        raw_out = ou, wg, lwg, lbg
+    else:
+        rank = 0
 
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
-        print('{} abs err {}'.format(name, err))
+        print('Rank {} {} abs err {}'.format(rank, name, err))
+        if err > 1e-3:
+            sys.stderr.write('=========== moe out ==============\n')
+            sys.stderr.write('{}'.format(mo))
+            sys.stderr.write('=========== raw out ==============\n')
+            sys.stderr.write('{}'.format(ro))
+            return
 
 
 def test_dp():
@@ -158,7 +166,9 @@ def test_dp():
 if __name__ == '__main__':
-    torch.distributed.init_process_group(backend='mpi')
+    os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', 0)
+    os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', 1)
+    torch.distributed.init_process_group(backend='nccl')
     rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     if len(sys.argv) >= 2:
```
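The revised __main__ block relies on an env-style rendezvous: RANK and WORLD_SIZE are copied from the OpenMPI variables, while MASTER_ADDR and MASTER_PORT are exported by tests/dev_test.sh. A standalone sketch of that setup follows; it is not part of the commit, requires CUDA and NCCL to be available, and writes the fallbacks as the strings '0' and '1' because os.environ values must be strings.

```python
import os
import torch

# Rendezvous variables; dev_test.sh exports these when launched under mpirun.
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '36666')
os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')

# One process per GPU, NCCL backend as in the updated test.
torch.distributed.init_process_group(backend='nccl')
print(torch.distributed.get_rank(), torch.distributed.get_world_size())
```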