OpenDAS / FastMoE / Commits

Commit 069cf01a, authored Jan 11, 2021 by Rick Ho
Parent: a4f7f1da

    make moe run with cuda

Showing 6 changed files with 74 additions and 23 deletions (+74 -23).
Changed files:

    pytorch/cuda/moe.cpp             +7   -0
    pytorch/cuda/moe_cuda_kernel.cu  +29  -0
    pytorch/cuda/moe_cuda_kernel.h   +4   -0
    pytorch/cuda/moe_function.py     +22  -19
    pytorch/cuda/moe_test.py         +11  -3
    pytorch/cuda/run.sh              +1   -1
pytorch/cuda/moe.cpp

```diff
@@ -67,6 +67,12 @@ std::vector<torch::Tensor> moe_backward(
 #ifdef MOE_USE_NCCL
+std::vector<torch::Tensor> moe_expert_exchange(
+        torch::Tensor local_expert_count,
+        size_t num_expert, size_t n_workers) {
+    return moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
+}
+
 std::vector<torch::Tensor> moe_global_scatter(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
@@ -107,6 +113,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
     m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
 #ifdef MOE_USE_NCCL
+    m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
     m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
     m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
 #endif
```
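For orientation, here is a sketch of how the newly bound `expert_exchange` could be driven from Python. It is not part of the commit: it assumes the extension is importable as `moe_cuda` (as in `moe_function.py`), was built with `MOE_USE_NCCL` defined, and that the MPI process group has already been initialized, as `moe_test.py` does.

```python
# Hypothetical usage sketch, not part of this commit.
import torch
import torch.distributed as dist
import moe_cuda  # the extension built from pytorch/cuda with MOE_USE_NCCL

dist.init_process_group(backend='mpi')  # expert_exchange relies on MPI_Alltoall
num_expert, world_size = 2, dist.get_world_size()

# How many of this rank's tokens go to each (worker, expert) pair,
# flattened worker-major into world_size * num_expert int32 entries.
local_expert_count = torch.randint(
    0, 8, (world_size * num_expert,), dtype=torch.int32, device='cuda')

# global_expert_count: counts arriving from every worker for our experts;
# fwd_expert_count:    per-local-expert totals this rank must compute.
global_expert_count, fwd_expert_count = moe_cuda.expert_exchange(
    local_expert_count, num_expert, world_size)
```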
pytorch/cuda/moe_cuda_kernel.cu

```diff
@@ -82,6 +82,35 @@ void moe_cuda_expert_count_impl(
 #ifdef MOE_USE_NCCL
+void moe_cuda_expert_exchange_impl(
+        const int* local_expert_count,
+        int* global_expert_count,
+        int* fwd_expert_count,
+        int num_expert, int world_size) {
+    MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
+            global_expert_count, num_expert, MPI_INT,
+            MPI_COMM_WORLD);
+    for (int i = 0; i < num_expert; ++i) {
+        for (int j = 0; j < world_size; ++j) {
+            fwd_expert_count[i] += global_expert_count[i + j * num_expert];
+        }
+    }
+}
+
+std::vector<torch::Tensor> moe_cuda_expert_exchange(
+        torch::Tensor local_expert_count,
+        long num_expert, long n_workers) {
+    auto global_expert_count = torch::empty_like(local_expert_count);
+    auto fwe_options = torch::TensorOptions().dtype(local_expert_count.dtype());
+    auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);
+    moe_cuda_expert_exchange_impl(
+            local_expert_count.data_ptr<int>(),
+            global_expert_count.data_ptr<int>(),
+            fwd_expert_count.data_ptr<int>(),
+            num_expert, n_workers);
+    return {global_expert_count, fwd_expert_count};
+}
+
 template <typename scalar_t>
 void moe_cuda_global_scatter_impl(
     const scalar_t* local_input_buf,
```
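The exchange itself is an all-to-all of the per-expert counts followed by a per-expert reduction over workers. A plain-Python sketch of that arithmetic (toy counts, no MPI) makes the indexing explicit:

```python
# Illustrative sketch of what the exchange computes, with made-up counts.
# Each rank's local_expert_count is laid out worker-major: entry
# w * num_expert + e = tokens this rank routes to expert e on rank w.
num_expert, world_size = 2, 3

local_counts = [
    [1, 0, 2, 1, 0, 3],   # rank 0's local_expert_count
    [0, 2, 1, 1, 2, 0],   # rank 1's
    [3, 1, 0, 0, 1, 1],   # rank 2's
]

def exchange(rank):
    # MPI_Alltoall: this rank receives one num_expert-sized block from every rank j.
    global_count = []
    for j in range(world_size):
        global_count.extend(local_counts[j][rank * num_expert:(rank + 1) * num_expert])
    # fwd_expert_count[e] = total tokens this rank must run through its expert e.
    fwd_count = [sum(global_count[e + j * num_expert] for j in range(world_size))
                 for e in range(num_expert)]
    return global_count, fwd_count

print(exchange(0))  # ([1, 0, 0, 2, 3, 1], [4, 3])
```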
pytorch/cuda/moe_cuda_kernel.h

```diff
@@ -41,6 +41,10 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
         torch::Tensor global_expert_count,
         long batch_size, long n_workers);
 
+std::vector<torch::Tensor> moe_cuda_expert_exchange(
+        torch::Tensor local_expert_count,
+        long num_expert, long n_workers);
+
 #endif
 
 #endif  // MOE_CUDA_KERNEL_H
```
pytorch/cuda/moe_function.py

```diff
@@ -35,53 +35,56 @@ class MOEGlobal(Function):
         local_expert_count, pos = moe_cuda.expert_count(gate, world_size * num_expert)
-        global_expert_count = torch.empty_like(world_size, num_expert)
-        torch.distributed.all_to_all(global_expert_count,
-                local_expert_count.reshape(world_size, num_expert))
-        batch_size = int(global_expert_count.sum().item())
+        global_expert_count, fwd_expert_count = moe_cuda.expert_exchange(
+                local_expert_count, num_expert, world_size)
+        fwd_batch_size = int(fwd_expert_count.sum().item())
 
         local_input_buf, = moe_cuda.local_scatter(inp, pos)
         global_input_buf, = moe_cuda.global_scatter(local_input_buf,
                 local_expert_count, global_expert_count,
-                batch_size, world_size)
+                fwd_batch_size, world_size)
-        global_output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
+        global_output_buf, = moe_cuda.forward(global_input_buf, weight, fwd_expert_count)
         local_output_buf, = moe_cuda.global_gather(global_output_buf,
                 local_expert_count, global_expert_count,
                 inp.shape[0], world_size)
-        output = moe_cuda.local_gather(local_output_buf, pos)
+        output, = moe_cuda.local_gather(local_output_buf, pos)
 
-        variables = [input_buf, gate, weight, local_expert_count,
-                global_expert_count, pos, num_expert, batch_size, world_size]
+        variables = (global_input_buf, gate, weight,
+                local_expert_count, global_expert_count, fwd_expert_count,
+                pos)
+        ctx.moe_args = (num_expert, inp.shape[0], fwd_batch_size, world_size)
         ctx.save_for_backward(*variables)
-        return output[0]
+        return output
 
     @staticmethod
     def backward(ctx, grad_out):
-        (input_buf, gate, weight, local_expert_count, global_expert_count,
-                pos, num_expert, batch_size, world_size) = ctx.saved_tensors
+        (input_buf, gate, weight,
+                local_expert_count, global_expert_count, fwd_expert_count,
+                pos) = ctx.saved_tensors
+        num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args
 
         grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
         global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
                 local_expert_count, global_expert_count,
-                batch_size, world_size)
+                fwd_batch_size, world_size)
         grad_inp_buf, grad_weight = moe_cuda.backward(
-                global_grad_out_buf, input_buf, weight, expert_count)
+                global_grad_out_buf, input_buf, weight, fwd_expert_count)
-        local_grad_inp_buf = moe_cuda.global_gather(grad_inp_buf,
+        local_grad_inp_buf, = moe_cuda.global_gather(grad_inp_buf,
                 local_expert_count, global_expert_count,
-                batch_size, world_size)
+                local_batch_size, world_size)
         grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)
 
-        return grad_inp, None, grad_weight
+        return grad_inp, None, grad_weight, None
 
 
 def moe(inp, gate, weight, world_size):
     if world_size is not None:
-        return MOEGlobal.apply(inp, gate, weight)
+        return MOEGlobal.apply(inp, gate, weight, world_size)
     else:
         return MOELocal.apply(inp, gate, weight)
```
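The main bookkeeping change above is that the autograd function now tracks two batch sizes instead of one. A toy illustration of the distinction (made-up counts, same worker-major layout as in the kernel above):

```python
# Toy illustration, not part of this commit: why forward/backward need
# both a local batch size and a forwarded batch size.
world_size, num_expert = 2, 2

local_expert_count = [3, 1, 0, 2]    # tokens we send,    [dest worker][expert]
global_expert_count = [3, 1, 4, 0]   # tokens we receive, [src worker][expert]

local_batch_size = sum(local_expert_count)    # 6: rows of the local buffers
fwd_batch_size = sum(global_expert_count)     # 8: rows of global_input_buf

# forward's global_scatter is sized by fwd_batch_size (what arrives here),
# while backward's global_gather is sized by local_batch_size (the gradient
# rows that must travel back to this rank's own tokens).
print(local_batch_size, fwd_batch_size)       # 6 8
```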
pytorch/cuda/moe_test.py

```diff
@@ -82,18 +82,23 @@ def test():
     linear = nn.Linear(in_feat, in_feat).cuda()
-    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
+    if world_size > 1:
+        moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
+    else:
+        moe = MOELayer(num_expert, in_feat, out_feat).cuda()
     moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
     moe_raw.weight.data = moe.weight.data.clone()
 
     inp = torch.rand(batch_size, in_feat).cuda()
-    gate = torch.randint(low=0, high=num_expert * torch.distributed.get_world_size(), size=(batch_size,),
+    gate = torch.randint(low=0, high=num_expert * world_size, size=(batch_size,),
             requires_grad=False).int().cuda()
     # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()
 
     moe_out = test_module(moe, linear, inp.clone(), gate.clone())
+    print('hhh')
+    return
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
 
     names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
@@ -128,6 +133,9 @@ def test_dp():
 if __name__ == '__main__':
     torch.distributed.init_process_group(backend='mpi')
+    world_size = torch.distributed.get_world_size()
+    if world_size == 1:
+        world_size = None
     test()
     # print('{} / {}'.format(torch.distributed.get_rank(), torch.distributed.get_world_size()))
     # perf()
```
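In this commit the comparison against `MOELayer_raw` is short-circuited by the temporary `print('hhh'); return`. Once that is removed, a check along these lines would close the loop; the tuple layout of `moe_out` / `raw_out` is an assumption inferred from the `names` list and is not shown in the diff:

```python
# Hypothetical comparison helper, not part of this commit.
import torch

def compare(moe_out, raw_out, names, rtol=1e-4, atol=1e-6):
    # Compare each pair of tensors (output and gradients) by name.
    for name, a, b in zip(names, moe_out, raw_out):
        err = (a - b).abs().max().item()
        status = 'OK' if torch.allclose(a, b, rtol=rtol, atol=atol) else 'MISMATCH'
        print(f'{name}: {status} (max abs err {err:.3e})')
```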
pytorch/cuda/run.sh

```diff
@@ -8,7 +8,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
 export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
 if [ -z $1 ]
 then
-	python3 moe_test.py
+	python3 moe_test.py 2>logs/$OMPI_COMM_WORLD_RANK.log
 elif [ .$1 = '.test_all' ]
 then
 	for nexp in 1 2 4
```