OpenDAS / FastMoE · Commit 2d250fbf
Authored Jan 28, 2021 by Rick Ho
Parent: 293eef6d

    make test run on nccl version, but fails in correctness

Showing 7 changed files with 88 additions and 46 deletions:
cuda/moe_comm_kernel.cu    +45  -30
cuda/moe_cuda_kernel.h      +3   -0
cuda/moe_fused_kernel.cu    +4   -5
fmoe/moe_function.py       +17   -3
setup.py                    +0   -1
tests/dev_test.sh           +2   -0
tests/moe_test.py          +17   -7
cuda/moe_comm_kernel.cu

@@ -13,50 +13,60 @@
 #include "cuda_stream_manager.h"
 
 #ifdef MOE_USE_NCCL
-#include <mpi.h>
 #include <nccl.h>
 
 void moe_cuda_expert_exchange_impl(
-        const int* local_expert_count,
-        int* global_expert_count,
-        int* fwd_expert_count,
+        const long* local_expert_count,
+        long* global_expert_count,
         int num_expert,
-        int world_size) {
-    MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
-            global_expert_count, num_expert, MPI_INT,
-            MPI_COMM_WORLD);
-    for (int i = 0; i < num_expert; ++i) {
-        for (int j = 0; j < world_size; ++j) {
-            fwd_expert_count[i] += global_expert_count[i + j * num_expert];
-        }
-    }
+        int world_size,
+        CudaStreamManager* smgr) {
+    NCCL_SAFE_CALL(ncclGroupStart());
+    for (int i = 0; i < world_size; ++i) {
+        NCCL_SAFE_CALL(ncclSend(
+                local_expert_count + num_expert * i,
+                num_expert,
+                ncclInt64,
+                i,
+                smgr->ncclcomm,
+                smgr->stream(0)));
+        NCCL_SAFE_CALL(ncclRecv(
+                global_expert_count + num_expert * i,
+                num_expert,
+                ncclInt64,
+                i,
+                smgr->ncclcomm,
+                smgr->stream(0)));
+    }
+    NCCL_SAFE_CALL(ncclGroupEnd());
+    smgr->sync(1);
 }
 
 std::vector<torch::Tensor> moe_cuda_expert_exchange(
         torch::Tensor local_expert_count,
         long num_expert, long n_workers) {
     auto global_expert_count = torch::empty_like(local_expert_count);
-    auto fwe_options = torch::TensorOptions()
-            .dtype(local_expert_count.dtype());
-    auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);
+    auto smgr = getCudaStreamManager(local_expert_count.device().index());
     moe_cuda_expert_exchange_impl(
-            local_expert_count.data_ptr<int>(),
-            global_expert_count.data_ptr<int>(),
-            fwd_expert_count.data_ptr<int>(),
+            local_expert_count.data_ptr<long>(),
+            global_expert_count.data_ptr<long>(),
             num_expert,
-            n_workers);
-    return {global_expert_count, fwd_expert_count};
+            n_workers,
+            smgr);
+    return {global_expert_count};
 }
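Note: the MPI_Alltoall is replaced by an explicit all-to-all built from paired ncclSend/ncclRecv calls inside an ncclGroupStart()/ncclGroupEnd() bracket, and the reduction into fwd_expert_count moves out of this kernel into Python (see fmoe/moe_function.py below). A minimal sketch of the same exchange pattern at the PyTorch level, assuming an initialized process group whose NCCL backend supports all_to_all_single; the function name is illustrative, not from the repo:

import torch
import torch.distributed as dist

def expert_exchange(local_expert_count: torch.Tensor,
                    num_expert: int, world_size: int) -> torch.Tensor:
    # local_expert_count: CUDA int64 tensor of length world_size * num_expert;
    # slice i holds the counts destined for rank i (matching ncclInt64 above).
    global_expert_count = torch.empty_like(local_expert_count)
    # One fused all-to-all replaces the world_size send/recv pairs that the
    # CUDA kernel issues inside its NCCL group.
    dist.all_to_all_single(global_expert_count, local_expert_count)
    return global_expert_count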
 
 template <typename scalar_t>
 void moe_cuda_global_scatter_impl(
     const scalar_t* local_input_buf,
-    const int* local_expert_count,
-    const int* global_expert_count,
+    const long* local_expert_count,
+    const long* global_expert_count,
     scalar_t* input_buf,
     size_t in_feat, size_t num_expert, size_t world_size,
     CudaStreamManager* smgr) {
     // assert world_size > 1
     int recv_ptr = 0;
     /* TODO: may save for backward */
-    int* expert_ptr = new int[num_expert * world_size];
+    long* expert_ptr = new long[num_expert * world_size];
     expert_ptr[0] = 0;
     for (int i = 1; i < num_expert * world_size; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
...
@@ -106,8 +116,8 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
"moe_cuda_global_scatter"
,
([
&
]
{
"moe_cuda_global_scatter"
,
([
&
]
{
moe_cuda_global_scatter_impl
<
scalar_t
>
(
moe_cuda_global_scatter_impl
<
scalar_t
>
(
input_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
local_expert_count
.
data_ptr
<
int
>
(),
local_expert_count
.
data_ptr
<
long
>
(),
global_expert_count
.
data_ptr
<
int
>
(),
global_expert_count
.
data_ptr
<
long
>
(),
global_input_buf
.
data_ptr
<
scalar_t
>
(),
global_input_buf
.
data_ptr
<
scalar_t
>
(),
in_feat
,
num_expert
,
n_workers
,
in_feat
,
num_expert
,
n_workers
,
smgr
smgr
...
@@ -119,14 +129,14 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
...
@@ -119,14 +129,14 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
 template <typename scalar_t>
 void moe_cuda_global_gather_impl(
     const scalar_t* output_buf,
-    const int* local_expert_count,
-    const int* global_expert_count,
+    const long* local_expert_count,
+    const long* global_expert_count,
     scalar_t* local_output_buf,
     size_t out_feat, size_t num_expert, size_t world_size,
     CudaStreamManager* smgr) {
-    int send_ptr = 0;
+    long send_ptr = 0;
     /* TODO: may save for backward */
-    int* expert_ptr = new int[num_expert * world_size];
+    long* expert_ptr = new long[num_expert * world_size];
     expert_ptr[0] = 0;
     for (int i = 1; i < num_expert * world_size; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
...
@@ -176,8 +186,8 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
"moe_cuda_global_gather"
,
([
&
]
{
"moe_cuda_global_gather"
,
([
&
]
{
moe_cuda_global_gather_impl
<
scalar_t
>
(
moe_cuda_global_gather_impl
<
scalar_t
>
(
output_buf
.
data_ptr
<
scalar_t
>
(),
output_buf
.
data_ptr
<
scalar_t
>
(),
local_expert_count
.
data_ptr
<
int
>
(),
local_expert_count
.
data_ptr
<
long
>
(),
global_expert_count
.
data_ptr
<
int
>
(),
global_expert_count
.
data_ptr
<
long
>
(),
local_output_buf
.
data_ptr
<
scalar_t
>
(),
local_output_buf
.
data_ptr
<
scalar_t
>
(),
out_feat
,
num_expert
,
n_workers
,
out_feat
,
num_expert
,
n_workers
,
smgr
smgr
...
@@ -186,4 +196,9 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
...
@@ -186,4 +196,9 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
     return {local_output_buf, };
 }
 
+void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
+    auto smgr = getCudaStreamManager(0);
+    smgr->ensure((void*)&p, t.device());
+}
+
 #endif
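Note: moe_ensure_nccl is the new bridge that lets the extension reuse the NCCL communicator of an existing PyTorch process group instead of bootstrapping one through MPI. A usage sketch of how it is driven from Python, mirroring the call added in fmoe/moe_function.py (fmoe_cuda is the extension module built by setup.py; the tensor name is illustrative):

import torch
import fmoe_cuda  # extension module built by setup.py

torch.distributed.init_process_group(backend='nccl')
t = torch.zeros(1, device='cuda')  # any tensor on the target device
# Hand the default process group to the extension so the CudaStreamManager
# can pick up a communicator matching the tensor's device.
fmoe_cuda.ensure_nccl(torch.distributed.distributed_c10d._default_pg, t)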
cuda/moe_cuda_kernel.h

@@ -41,6 +41,9 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
     torch::Tensor global_expert_count,
     long batch_size, long n_workers);
 
+#include <c10d/ProcessGroupNCCL.hpp>
+void moe_ensure_nccl(c10d::ProcessGroupNCCL&, torch::Tensor t);
+
 std::vector<torch::Tensor> moe_cuda_expert_exchange(
     torch::Tensor local_expert_count,
     long num_expert, long n_workers);
...
cuda/moe_fused_kernel.cu

@@ -14,7 +14,6 @@
 #include "cublas_wrapper.h"
 
 #ifdef MOE_USE_NCCL
-#include <mpi.h>
 #include <nccl.h>
 
 template <typename scalar_t>
@@ -24,8 +23,8 @@ void moe_cuda_global_fused_forward_impl(
     scalar_t* global_input_buf,
     scalar_t* global_output_buf,
     scalar_t* output_buf,
-    const int* local_expert_count,
-    const int* global_expert_count,
+    const long* local_expert_count,
+    const long* global_expert_count,
     long in_feat, long out_feat,
     long num_expert, long world_size,
     CudaStreamManager* smgr) {
@@ -136,8 +135,8 @@ std::vector<torch::Tensor> moe_cuda_global_fused_forward(
         global_input_buf.data_ptr<scalar_t>(),
         global_output_buf.data_ptr<scalar_t>(),
         output_buf.data_ptr<scalar_t>(),
-        local_expert_count.data_ptr<int>(),
-        global_expert_count.data_ptr<int>(),
+        local_expert_count.data_ptr<long>(),
+        global_expert_count.data_ptr<long>(),
         in_feat, out_feat, num_expert, n_workers,
         smgr);
 }));
...
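Note: the signature change from const int* to const long* follows the count tensors, which are now created as torch.long so they can travel over NCCL as ncclInt64. A small reminder of the dtype contract, since data_ptr<T>() is type-checked at runtime; the tensor name is illustrative:

import torch

# Count tensors must be torch.long (int64) to match the C++ side's
# data_ptr<long>(); a mismatched dtype raises at the data_ptr<T>() call.
local_expert_count = torch.zeros(8, dtype=torch.long, device='cuda')
assert local_expert_count.dtype == torch.int64  # torch.long is an alias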
fmoe/moe_function.py

@@ -38,16 +38,30 @@ class MOELocal(Function):
 class MOEGlobal(Function):
     @staticmethod
     def forward(ctx, inp, gate, weight, world_size):
+        fmoe_cuda.ensure_nccl(torch.distributed.distributed_c10d._default_pg,
+                inp)
         num_expert = weight.shape[0]
-        local_expert_count, pos = fmoe_cuda.expert_count(gate,
-                world_size * num_expert)
-        global_expert_count, fwd_expert_count = fmoe_cuda.expert_exchange(
+        # local_expert_count, pos = fmoe_cuda.expert_count(gate,
+        #         world_size * num_expert)
+        _, pos = torch.sort(gate)
+        gate_idx, gate_count = torch.unique(gate, return_counts=True)
+        local_expert_count = torch.zeros(weight.shape[0] * world_size,
+                device=weight.device, dtype=torch.long)
+        local_expert_count.index_put_((gate_idx.long(),), gate_count)
+        global_expert_count, = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size)
+        print('Local {} Global {}'.format(local_expert_count,
+                global_expert_count))
+        fwd_expert_count = global_expert_count.view(num_expert,
+                world_size).sum(dim=1).cpu()
         fwd_batch_size = int(fwd_expert_count.sum().item())
         local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
+        local_expert_count = local_expert_count.cpu()
+        global_expert_count = global_expert_count.cpu()
         local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
                 local_input_buf, weight,
                 local_expert_count, global_expert_count,
...
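Note: the fmoe_cuda.expert_count kernel is swapped out for plain torch ops here: sort the gate to get the scatter positions, then histogram the gate indices into a per-expert count vector. A self-contained sketch of that histogram step with a concrete input (values are illustrative):

import torch

world_size, num_expert = 2, 2
gate = torch.tensor([2, 0, 2, 1])          # expert chosen per token
_, pos = torch.sort(gate)                  # token order grouped by expert
gate_idx, gate_count = torch.unique(gate, return_counts=True)
local_expert_count = torch.zeros(world_size * num_expert, dtype=torch.long)
local_expert_count.index_put_((gate_idx.long(),), gate_count)
print(local_expert_count)                  # tensor([1, 1, 2, 0])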
setup.py

@@ -8,7 +8,6 @@ cxx_flags = [
 ]
 
 if os.environ.get('USE_NCCL', '0') == '1':
     cxx_flags.append('-DMOE_USE_NCCL')
-    os.environ['CXX'] = 'mpicxx'
 
 if __name__ == '__main__':
     setuptools.setup(
...
tests/dev_test.sh

@@ -2,6 +2,8 @@
 if [ ! -z $OMPI_COMM_WORLD_LOCAL_RANK ]
 then
     export CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_LOCAL_RANK
+    export MASTER_ADDR=localhost
+    export MASTER_PORT=36666
 fi
 if [ -z $OMPI_COMM_WORLD_RANK ]
...
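Note: MASTER_ADDR and MASTER_PORT are needed because the test now initializes torch.distributed with the NCCL backend, whose default env:// rendezvous reads them along with RANK and WORLD_SIZE. A sketch of the full variable set under mpirun, matching what dev_test.sh and the new moe_test.py __main__ arrange; note os.environ values must be strings:

import os
import torch

# Rendezvous endpoint exported by dev_test.sh.
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '36666')
# Bridge Open MPI's rank/size variables; string defaults are used here
# (the commit passes int defaults, which would throw if the vars are unset).
os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', '0')
os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', '1')
torch.distributed.init_process_group(backend='nccl')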
tests/moe_test.py

@@ -4,6 +4,7 @@ import torch
 from torch import nn
 import time
 import sys
+import os
 
 dev_name_default = 'cuda:0'
@@ -105,10 +106,10 @@ def test():
     if world_size == 1:
         moe_raw.weight.data = moe.weight.data.clone()
     else:
-        weight_array = [torch.empty_like(moe.weight.data).cpu()
+        weight_array = [torch.empty_like(moe.weight.data)
                 for _ in range(world_size)]
-        torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
-        moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()
+        torch.distributed.all_gather(weight_array, moe.weight.data)
+        moe_raw.weight.data = torch.cat(weight_array, dim=0)
 
     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0,
@@ -124,13 +125,20 @@ def test():
     if world_size > 1:
         rank = torch.distributed.get_rank()
         ou, wg, lwg, lbg = raw_out
-        wg = wg.cpu()
         torch.distributed.all_reduce(wg)
         wg = wg[rank * num_expert:(rank + 1) * num_expert]
-        raw_out = ou, wg.cuda(), lwg, lbg
+        raw_out = ou, wg, lwg, lbg
+    else:
+        rank = 0
 
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
-        print('{} abs err {}'.format(name, err))
+        print('Rank {} {} abs err {}'.format(rank, name, err))
+        if err > 1e-3:
+            sys.stderr.write('=========== moe out ==============\n')
+            sys.stderr.write('{}'.format(mo))
+            sys.stderr.write('=========== raw out ==============\n')
+            sys.stderr.write('{}'.format(ro))
+            return
@@ -158,7 +166,9 @@ def test_dp():
 if __name__ == '__main__':
-    torch.distributed.init_process_group(backend='mpi')
+    os.environ['RANK'] = os.environ.get('OMPI_COMM_WORLD_RANK', 0)
+    os.environ['WORLD_SIZE'] = os.environ.get('OMPI_COMM_WORLD_SIZE', 1)
+    torch.distributed.init_process_group(backend='nccl')
     rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     if len(sys.argv) >= 2:
...
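Note: with the NCCL backend, the reference-gradient check stays on the GPU (the .cpu()/.cuda() round trips are dropped): the raw model holds all experts' weights, so summing its gradient across ranks and slicing out this rank's num_expert rows makes it comparable to the model-parallel gradient. A compact sketch of that reduction; the function name is illustrative:

import torch
import torch.distributed as dist

def shard_reference_grad(wg: torch.Tensor, num_expert: int) -> torch.Tensor:
    # wg: gradient of the replicated raw weight, shape
    # [world_size * num_expert, ...]; kept on GPU so NCCL can reduce it.
    dist.all_reduce(wg)                     # sum the gradient over all ranks
    rank = dist.get_rank()
    # keep only the experts owned by this rank for comparison
    return wg[rank * num_expert:(rank + 1) * num_expert]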