Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
861b75c1
Commit
861b75c1
authored
Jan 10, 2021
by
Rick Ho
Browse files
remove weight input to c expert count function
parent
60b93e39
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
12 deletions
+10
-12
pytorch/cuda/moe.cpp
pytorch/cuda/moe.cpp
+4
-5
pytorch/cuda/moe.py
pytorch/cuda/moe.py
+1
-1
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+2
-3
pytorch/cuda/run.sh
pytorch/cuda/run.sh
+3
-3
No files found.
pytorch/cuda/moe.cpp
View file @
861b75c1
...
@@ -5,8 +5,7 @@
...
@@ -5,8 +5,7 @@
#include <vector>
#include <vector>
std
::
vector
<
torch
::
Tensor
>
moe_cuda_expert_count
(
std
::
vector
<
torch
::
Tensor
>
moe_cuda_expert_count
(
torch
::
Tensor
weight
,
// TODO: pass num-experts in another way?
torch
::
Tensor
gate
,
size_t
num_expert
);
torch
::
Tensor
gate
);
std
::
vector
<
torch
::
Tensor
>
moe_cuda_local_scatter
(
std
::
vector
<
torch
::
Tensor
>
moe_cuda_local_scatter
(
torch
::
Tensor
input
,
torch
::
Tensor
input
,
...
@@ -35,10 +34,10 @@ std::vector<torch::Tensor> moe_cuda_backward(
...
@@ -35,10 +34,10 @@ std::vector<torch::Tensor> moe_cuda_backward(
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std
::
vector
<
torch
::
Tensor
>
moe_expert_count
(
std
::
vector
<
torch
::
Tensor
>
moe_expert_count
(
torch
::
Tensor
weight
,
torch
::
Tensor
gate
,
torch
::
Tensor
gate
)
{
size_t
num_expert
)
{
CHECK_INPUT
(
gate
);
CHECK_INPUT
(
gate
);
return
moe_cuda_expert_count
(
weight
,
gate
);
return
moe_cuda_expert_count
(
gate
,
num_expert
);
}
}
std
::
vector
<
torch
::
Tensor
>
moe_local_scatter
(
std
::
vector
<
torch
::
Tensor
>
moe_local_scatter
(
...
...
pytorch/cuda/moe.py
View file @
861b75c1
...
@@ -11,7 +11,7 @@ class MOEFunction(Function):
...
@@ -11,7 +11,7 @@ class MOEFunction(Function):
def
forward
(
ctx
,
inp
,
gate
,
weight
):
def
forward
(
ctx
,
inp
,
gate
,
weight
):
# out_feat, in_feat = weight.size()[1:]
# out_feat, in_feat = weight.size()[1:]
# weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
# weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
expert_count
,
pos
=
moe_cuda
.
expert_count
(
weight
,
gate
)
expert_count
,
pos
=
moe_cuda
.
expert_count
(
gate
,
weight
.
shape
[
0
]
)
input_buf
,
=
moe_cuda
.
local_scatter
(
inp
,
pos
)
input_buf
,
=
moe_cuda
.
local_scatter
(
inp
,
pos
)
output_buf
,
=
moe_cuda
.
forward
(
input_buf
,
weight
,
expert_count
)
output_buf
,
=
moe_cuda
.
forward
(
input_buf
,
weight
,
expert_count
)
output
=
moe_cuda
.
local_gather
(
output_buf
,
pos
)
output
=
moe_cuda
.
local_gather
(
output_buf
,
pos
)
...
...
pytorch/cuda/moe_cuda_kernel.cu
View file @
861b75c1
...
@@ -199,10 +199,9 @@ void moe_cuda_backward_impl(
...
@@ -199,10 +199,9 @@ void moe_cuda_backward_impl(
std
::
vector
<
torch
::
Tensor
>
moe_cuda_expert_count
(
std
::
vector
<
torch
::
Tensor
>
moe_cuda_expert_count
(
torch
::
Tensor
weight
,
torch
::
Tensor
gate
,
torch
::
Tensor
gate
)
{
size_t
num_expert
)
{
const
auto
batch_size
=
gate
.
size
(
0
);
const
auto
batch_size
=
gate
.
size
(
0
);
const
auto
num_expert
=
weight
.
size
(
0
);
auto
ec_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
);
auto
ec_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
);
auto
expert_count
=
torch
::
empty
(
num_expert
,
ec_options
);
auto
expert_count
=
torch
::
empty
(
num_expert
,
ec_options
);
...
...
pytorch/cuda/run.sh
View file @
861b75c1
...
@@ -3,7 +3,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
...
@@ -3,7 +3,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export
LD_LIBRARY_PATH
=
/home/laekov/.local/lib/python3.7/site-packages/torch/lib:
$LD_LIBRARY_PATH
export
LD_LIBRARY_PATH
=
/home/laekov/.local/lib/python3.7/site-packages/torch/lib:
$LD_LIBRARY_PATH
if
[
-z
$1
]
if
[
-z
$1
]
then
then
python moe.py
python
3
moe.py
elif
[
.
$1
=
'.test_all'
]
elif
[
.
$1
=
'.test_all'
]
then
then
for
nexp
in
1 2 4
for
nexp
in
1 2 4
...
@@ -15,11 +15,11 @@ then
...
@@ -15,11 +15,11 @@ then
for
bs
in
4 16 64 256 512 1024 2048 4096
for
bs
in
4 16 64 256 512 1024 2048 4096
do
do
echo
$bs
$nexp
${
inf
}
x
${
ouf
}
echo
$bs
$nexp
${
inf
}
x
${
ouf
}
python moe_test.py
$bs
$inf
$ouf
$nexp
python
3
moe_test.py
$bs
$inf
$ouf
$nexp
done
done
done
done
done
done
done
done
else
else
python
$@
python
3
$@
fi
fi
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment