OpenDAS / FastMoE · Commits
Commit 284f1424
authored Jan 09, 2021 by Rick Ho

    degrade to single fc fwd

parent d690c7b2
Showing 3 changed files with 33 additions and 71 deletions
pytorch/cuda/moe.cpp               +4   -7
pytorch/cuda/moe.py                +19  -34
pytorch/cuda/moe_cuda_kernel.cu    +10  -30
pytorch/cuda/moe.cpp
@@ -7,8 +7,7 @@
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input,
         torch::Tensor gate,
-        torch::Tensor weight1,
-        torch::Tensor weight2);
+        torch::Tensor weight);
 
 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output,
@@ -26,19 +25,17 @@ std::vector<torch::Tensor> moe_cuda_backward(
 std::vector<torch::Tensor> moe_forward(
         torch::Tensor input,   // [batch_size x in_feat]
         torch::Tensor gate,    // [batch_size]
-        torch::Tensor weight1, // [num_expert x hidden_feat x in_feat]
-        torch::Tensor weight2  // [num_expert x out_feat x hidden_feat]
+        torch::Tensor weight   // [num_expert x hidden_feat x in_feat]
         ) {
     CHECK_INPUT(input);
     CHECK_INPUT(gate);
-    CHECK_INPUT(weight1);
-    CHECK_INPUT(weight2);
+    CHECK_INPUT(weight);
     /*
         The bias term should have been merged into weight. Note the following fact that
         Wx+b = [W b] [x]
                      [1]
     */
-    return moe_cuda_forward(input, gate, weight1, weight2);
+    return moe_cuda_forward(input, gate, weight);
 }
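The bias-folding fact quoted in the comment inside moe_forward can be checked directly. Below is a minimal sketch in plain PyTorch (shapes and names are illustrative, not from the repo) verifying that W x + b equals [W b] applied to x with a 1 appended:

    import torch

    out_feat, in_feat = 3, 5
    W = torch.rand(out_feat, in_feat)
    b = torch.rand(out_feat)
    x = torch.rand(in_feat)

    merged = torch.cat([W, b.unsqueeze(1)], dim=1)   # [W b], shape [out_feat, in_feat + 1]
    x1 = torch.cat([x, torch.ones(1)])               # [x; 1], shape [in_feat + 1]

    assert torch.allclose(W @ x + b, merged @ x1)    # Wx + b == [W b][x; 1]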
pytorch/cuda/moe.py
@@ -10,11 +10,11 @@ torch.cuda.manual_seed(42)
 class MOEFunction(Function):
     @staticmethod
-    def forward(ctx, inp, gate, weight1, weight2):
+    def forward(ctx, inp, gate, weight):
         # out_feat, in_feat = weight.size()[1:]
         # weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
-        output = moe_cuda.forward(inp, gate, weight1, weight2)
-        variables = [inp, gate, weight1, weight2]
+        output = moe_cuda.forward(inp, gate, weight)
+        variables = [inp, gate, weight]
         ctx.save_for_backward(*variables)
         return output[0]
@@ -32,59 +32,46 @@ class MOEFunction(Function):
 class MOELayer(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
+    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
         super(MOELayer, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
-        self.hidden_feat = hidden_feat
         self.out_feat = out_feat
-        self.weight1 = nn.Parameter(
-                torch.Tensor(num_expert, hidden_feat, in_feat))
-        self.weight2 = nn.Parameter(
-                torch.Tensor(num_expert, out_feat, hidden_feat))
+        self.weight = nn.Parameter(
+                torch.Tensor(num_expert, out_feat, in_feat))
         self.reset_parameters()

     def reset_parameters(self):
         for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
-            self.weight1.data[i] = linear.weight.data
-            linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
-            self.weight2.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
+            self.weight.data[i] = linear.weight.data

     def forward(self, inp, gate):
-        return MOEFunction.apply(inp, gate, self.weight1, self.weight2)
+        return MOEFunction.apply(inp, gate, self.weight)


 class MOELayer_raw(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
+    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
         super(MOELayer_raw, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
-        self.hidden_feat = hidden_feat
         self.out_feat = out_feat
-        self.weight1 = nn.Parameter(
-                torch.Tensor(num_expert, hidden_feat, in_feat))
-        self.weight2 = nn.Parameter(
-                torch.Tensor(num_expert, out_feat, hidden_feat))
+        self.weight = nn.Parameter(
+                torch.Tensor(num_expert, out_feat, in_feat))
         self.reset_parameters()

     def reset_parameters(self):
         for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
-            print(linear.weight.shape)
-            self.weight1.data[i] = linear.weight.data
-            linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
-            self.weight2.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
+            # print(linear.weight.shape)
+            self.weight.data[i] = linear.weight.data

     def forward(self, inp, gate):
         gate_long = gate.long()
         batch_size = inp.size(0)
         x = inp.new_zeros((batch_size, self.out_feat))
-        print(self.weight2)
         for i in range(batch_size):
-            hid = inp[i] @ self.weight1[gate_long[i]].t()
-            print(hid)
-            x[i] = hid @ self.weight2[gate_long[i]].t()
+            x[i] = inp[i] @ self.weight[gate_long[i]].t()
         return x
@@ -105,15 +92,13 @@ def test():
     batch_size = 4
     num_expert = 2
     in_feat = 6
-    hidden_feat = 12
     out_feat = 7

     linear = nn.Linear(in_feat, in_feat).cuda()
-    moe = MOELayer(num_expert, in_feat, hidden_feat, out_feat).cuda()
-    moe_raw = MOELayer_raw(num_expert, in_feat, hidden_feat, out_feat).cuda()
-    moe_raw.weight1.data = moe.weight1.data.clone()
-    moe_raw.weight2.data = moe.weight2.data.clone()
+    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
+    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
+    moe_raw.weight.data = moe.weight.data.clone()

     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0, high=num_expert, size=(batch_size,), requires_grad=False).int().cuda()
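After this change both MOELayer and MOELayer_raw compute the single-fc forward out[i] = inp[i] @ weight[gate[i]].t(). The sketch below (plain PyTorch on CPU, no moe_cuda extension; function names are illustrative) checks that this per-sample loop agrees with the grouped form the CUDA kernel uses, where samples routed to the same expert are gathered and handled by one matmul per expert:

    import torch

    def moe_forward_loop(inp, gate, weight):
        # inp: [batch, in_feat], gate: [batch] int, weight: [num_expert, out_feat, in_feat]
        out = inp.new_zeros(inp.size(0), weight.size(1))
        for i in range(inp.size(0)):
            out[i] = inp[i] @ weight[gate[i]].t()
        return out

    def moe_forward_grouped(inp, gate, weight):
        out = inp.new_zeros(inp.size(0), weight.size(1))
        order = torch.argsort(gate)              # samples of the same expert become contiguous
        ptr = 0
        for e in range(weight.size(0)):
            cnt = int((gate == e).sum())         # plays the role of expert_count[e]
            idx = order[ptr:ptr + cnt]
            out[idx] = inp[idx] @ weight[e].t()  # one GEMM per expert
            ptr += cnt
        return out

    inp = torch.rand(4, 6)
    gate = torch.randint(0, 2, (4,))
    weight = torch.rand(2, 7, 6)
    assert torch.allclose(moe_forward_loop(inp, gate, weight),
                          moe_forward_grouped(inp, gate, weight), atol=1e-6)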
pytorch/cuda/moe_cuda_kernel.cu
@@ -58,12 +58,10 @@ template <typename scalar_t>
 void moe_cuda_forward_impl(
         const scalar_t* input,
         const int* d_gate,
-        const scalar_t* weight1,
-        const scalar_t* weight2,
+        const scalar_t* weight,
         scalar_t* output,
         const size_t batch_size,
         const size_t in_feat,
-        const size_t hidden_feat,
         const size_t out_feat,
         const size_t num_expert) {
@@ -73,14 +71,12 @@ void moe_cuda_forward_impl(
     timestamp(t_init);
 #endif

-    scalar_t *input_buf, *hidden_buf, *output_buf;
+    scalar_t *input_buf, *output_buf;
     checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
             in_feat));
     checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
             out_feat));
-    checkCudaErrors(cudaMalloc(&hidden_buf, sizeof(scalar_t) * batch_size *
-            hidden_feat));

 #ifdef MOE_BREAKDOWN
     timestamp(t_malloc);
@@ -152,22 +148,11 @@ void moe_cuda_forward_impl(
         checkCudaErrors(cublasXgemm(h->getHandle(i),
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
-                hidden_feat, expert_count[i], in_feat,
+                out_feat, expert_count[i], in_feat,
                 &alpha,
-                weight1 + i * in_feat * hidden_feat, in_feat,
+                weight + i * in_feat * out_feat, in_feat,
                 input_buf + ptr * in_feat, in_feat,
                 &beta,
-                hidden_buf + hidden_feat * ptr, hidden_feat));
-        checkCudaErrors(cublasXgemm(h->getHandle(i),
-                CUBLAS_OP_T,
-                CUBLAS_OP_N,
-                out_feat, expert_count[i], hidden_feat,
-                &alpha,
-                weight2 + i * hidden_feat * out_feat, hidden_feat,
-                hidden_buf + hidden_feat * ptr, hidden_feat,
-                &beta,
-                output_buf + out_feat * ptr, out_feat
-                ));
+                output_buf + out_feat * ptr, out_feat
+                ));
@@ -195,7 +180,6 @@ void moe_cuda_forward_impl(
 #endif
     cudaFree(input_buf);
-    cudaFree(hidden_buf);
     cudaFree(output_buf);
     cudaFree(d_pos);
     delete [] pos;
@@ -244,17 +228,15 @@ void moe_cuda_grad_weight(
...
@@ -244,17 +228,15 @@ void moe_cuda_grad_weight(
std
::
vector
<
torch
::
Tensor
>
moe_cuda_forward
(
std
::
vector
<
torch
::
Tensor
>
moe_cuda_forward
(
torch
::
Tensor
input
,
torch
::
Tensor
input
,
torch
::
Tensor
gate
,
torch
::
Tensor
gate
,
torch
::
Tensor
weight1
,
torch
::
Tensor
weight
torch
::
Tensor
weight2
)
{
)
{
const
auto
batch_size
=
input
.
size
(
0
);
const
auto
batch_size
=
input
.
size
(
0
);
const
auto
num_expert
=
weight1
.
size
(
0
);
const
auto
num_expert
=
weight
.
size
(
0
);
const
auto
out_feat
=
weight2
.
size
(
1
);
const
auto
out_feat
=
weight
.
size
(
1
);
const
auto
hidden_feat
=
weight1
.
size
(
1
);
const
auto
in_feat
=
weight
.
size
(
2
);
const
auto
in_feat
=
weight1
.
size
(
2
);
#ifdef MOE_DEBUG
#ifdef MOE_DEBUG
printf
(
"[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld,
hidden_feat = %ld,
out_feat (d_ffn)=%ld
\n
"
,
batch_size
,
num_expert
,
in_feat
,
hidden_feat
,
out_feat
);
printf
(
"[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld
\n
"
,
batch_size
,
num_expert
,
in_feat
,
out_feat
);
#endif
#endif
auto
output
=
input
.
new_zeros
({
batch_size
,
out_feat
});
auto
output
=
input
.
new_zeros
({
batch_size
,
out_feat
});
@@ -262,12 +244,10 @@ std::vector<torch::Tensor> moe_cuda_forward(
         moe_cuda_forward_impl<scalar_t>(
             input.data_ptr<scalar_t>(),
             gate.data_ptr<int>(),
-            weight1.data_ptr<scalar_t>(),
-            weight2.data_ptr<scalar_t>(),
+            weight.data_ptr<scalar_t>(),
             output.data_ptr<scalar_t>(),
             batch_size,
             in_feat,
-            hidden_feat,
             out_feat,
             num_expert
         );
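A note on the surviving cublasXgemm call in moe_cuda_forward_impl: cuBLAS treats buffers as column-major, so the row-major [out_feat, in_feat] expert slice of weight is seen as an (in_feat x out_feat) matrix with lda = in_feat and is flipped back with CUBLAS_OP_T, while the expert's rows of input_buf are consumed untransposed; the (out_feat x expert_count[i]) result stored with ldc = out_feat reads back row-major as [expert_count[i], out_feat]. A small sketch of that bookkeeping (plain PyTorch standing in for cuBLAS; it mirrors the call rather than quoting repo code, and the names are illustrative):

    import torch

    out_feat, in_feat, cnt = 7, 6, 4
    w_e = torch.rand(out_feat, in_feat)   # one expert's slice of weight, row-major
    x_e = torch.rand(cnt, in_feat)        # this expert's rows of input_buf, row-major

    # What cuBLAS sees when it reads the same row-major buffers column-major:
    A = w_e.t()                           # (in_feat, out_feat), lda = in_feat
    B = x_e.t()                           # (in_feat, cnt),      ldb = in_feat

    C = A.t() @ B                         # CUBLAS_OP_T on A, CUBLAS_OP_N on B -> (out_feat, cnt)

    # Read back row-major with ldc = out_feat, this is x_e @ w_e.t(), i.e. [cnt, out_feat].
    assert torch.allclose(C.t(), x_e @ w_e.t())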