OpenDAS / FastMoE · Commits

Commit 35addec6
Authored Dec 30, 2020 by Rick Ho

    two-level matmul fix transpose

Parent: 191c1e46
Showing 3 changed files with 140 additions and 49 deletions (+140 / -49):

    pytorch/cuda/moe.cpp              +9   -6
    pytorch/cuda/moe.py               +42  -23
    pytorch/cuda/moe_cuda_kernel.cu   +89  -20
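For orientation: the "two-level matmul" in the commit message replaces the single per-expert projection (one weight tensor per expert) with an FFN-style pair of projections, weight1 [num_expert x hidden_feat x in_feat] followed by weight2 [num_expert x out_feat x hidden_feat]. A minimal sketch of the per-token computation the diff below implements (variable names here are illustrative, not from the repository):

    import torch

    num_expert, in_feat, hidden_feat, out_feat = 2, 6, 12, 7
    w1 = torch.randn(num_expert, hidden_feat, in_feat)   # first-level expert weights
    w2 = torch.randn(num_expert, out_feat, hidden_feat)  # second-level expert weights

    x = torch.randn(in_feat)   # one token
    e = 1                      # expert index chosen by the gate for this token

    hid = x @ w1[e].t()        # level 1: [in_feat] -> [hidden_feat]
    y   = hid @ w2[e].t()      # level 2: [hidden_feat] -> [out_feat]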
pytorch/cuda/moe.cpp (view file @ 35addec6)

@@ -7,13 +7,14 @@
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input,
         torch::Tensor gate,
-        torch::Tensor weight);
+        torch::Tensor weight1,
+        torch::Tensor weight2);

 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output,
         torch::Tensor input,
         torch::Tensor gate,
         torch::Tensor weight);

 // C++ interface

@@ -25,17 +26,19 @@ std::vector<torch::Tensor> moe_cuda_backward(
 std::vector<torch::Tensor> moe_forward(
         torch::Tensor input,   // [batch_size x in_feat]
         torch::Tensor gate,    // [batch_size]
-        torch::Tensor weight   // [num_expert x out_feat x in_feat]
+        torch::Tensor weight1, // [num_expert x hidden_feat x in_feat]
+        torch::Tensor weight2  // [num_expert x out_feat x hidden_feat]
         ) {
     CHECK_INPUT(input);
     CHECK_INPUT(gate);
-    CHECK_INPUT(weight);
+    CHECK_INPUT(weight1);
+    CHECK_INPUT(weight2);
     /*
        The bias term should have been merged into weight. Note the following fact that
        Wx+b = [W b] [x]
                     [1]
     */
-    return moe_cuda_forward(input, gate, weight);
+    return moe_cuda_forward(input, gate, weight1, weight2);
 }

 std::vector<torch::Tensor> moe_backward(

@@ -71,4 +74,4 @@ int main() {
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("forward", &moe_forward, "MoE forward (CUDA)");
     m.def("backward", &moe_backward, "MoE backward (CUDA)");
-}
\ No newline at end of file
+}
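The comment in moe_forward leans on the augmented-matrix identity Wx + b = [W b]·[x; 1]: a bias can always be folded into the weight by appending b as an extra column and a constant 1 to the input. A quick PyTorch check of that identity (shapes are illustrative):

    import torch

    out_feat, in_feat = 3, 5
    W = torch.randn(out_feat, in_feat)
    b = torch.randn(out_feat)
    x = torch.randn(in_feat)

    W_aug = torch.cat([W, b.unsqueeze(1)], dim=1)  # [out_feat, in_feat + 1]
    x_aug = torch.cat([x, torch.ones(1)])          # [in_feat + 1]

    assert torch.allclose(W @ x + b, W_aug @ x_aug)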
pytorch/cuda/moe.py (view file @ 35addec6)

@@ -10,11 +10,11 @@ torch.cuda.manual_seed(42)
 class MOEFunction(Function):
     @staticmethod
-    def forward(ctx, inp, gate, weight):
-        out_feat, in_feat = weight.size()[1:]
-        weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
-        output = moe_cuda.forward(inp, gate, weight_column_major)
-        variables = [inp, gate, weight_column_major]
+    def forward(ctx, inp, gate, weight1, weight2):
+        # out_feat, in_feat = weight.size()[1:]
+        # weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
+        output = moe_cuda.forward(inp, gate, weight1, weight2)
+        variables = [inp, gate, weight1, weight2]
         ctx.save_for_backward(*variables)
         return output[0]
@@ -32,45 +32,59 @@ class MOEFunction(Function):
 class MOELayer(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
+    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
         super(MOELayer, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
+        self.hidden_feat = hidden_feat
         self.out_feat = out_feat
-        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
+        self.weight1 = nn.Parameter(torch.Tensor(num_expert, hidden_feat, in_feat))
+        self.weight2 = nn.Parameter(torch.Tensor(num_expert, out_feat, hidden_feat))
         self.reset_parameters()

     def reset_parameters(self):
         for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
-            self.weight.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
+            self.weight1.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
+            self.weight2.data[i] = linear.weight.data

     def forward(self, inp, gate):
-        return MOEFunction.apply(inp, gate, self.weight)
+        return MOEFunction.apply(inp, gate, self.weight1, self.weight2)


 class MOELayer_raw(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
+    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
         super(MOELayer_raw, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
+        self.hidden_feat = hidden_feat
         self.out_feat = out_feat
-        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
+        self.weight1 = nn.Parameter(torch.Tensor(num_expert, hidden_feat, in_feat))
+        self.weight2 = nn.Parameter(torch.Tensor(num_expert, out_feat, hidden_feat))
         self.reset_parameters()

     def reset_parameters(self):
         for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
-            self.weight.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
+            print(linear.weight.shape)
+            self.weight1.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
+            self.weight2.data[i] = linear.weight.data

     def forward(self, inp, gate):
         gate_long = gate.long()
         batch_size = inp.size(0)
         x = inp.new_zeros((batch_size, self.out_feat))
+        print(self.weight2)
         for i in range(batch_size):
-            x[i] = self.weight[gate_long[i]] @ inp[i]
+            hid = inp[i] @ self.weight1[gate_long[i]].t()
+            print(hid)
+            x[i] = hid @ self.weight2[gate_long[i]].t()
         return x
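The MOELayer_raw loop above applies w1[e] and w2[e] one token at a time; the CUDA path in moe_cuda_kernel.cu instead groups tokens by expert and runs one pair of GEMMs per expert. A hedged pure-PyTorch sketch of that grouping (the helper name and shapes are illustrative, not part of the repository):

    import torch

    def moe_two_level_grouped(inp, gate, w1, w2):
        # Same math as the per-token loop in MOELayer_raw, but batched per expert.
        out = inp.new_zeros(inp.size(0), w2.size(1))
        for e in range(w1.size(0)):
            idx = (gate == e).nonzero(as_tuple=True)[0]  # tokens routed to expert e
            if idx.numel() == 0:
                continue
            hid = inp[idx] @ w1[e].t()                   # [n_e, hidden_feat]
            out[idx] = hid @ w2[e].t()                   # [n_e, out_feat]
        return out

On the same inputs this matches the per-token loop up to floating-point accumulation order.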
@@ -80,6 +94,8 @@ def test_module(moe, linear, inp, gate):
     x = linear(inp)
     output = moe(x, gate)
     print(output)
-    return output
+    y = output.mean()
+    y.backward()
+    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
@@ -87,15 +103,17 @@ def test_module(moe, linear, inp, gate):
 def test():
     batch_size = 4
-    num_expert = 4
-    in_feat = 2
-    out_feat = 3
+    num_expert = 2
+    in_feat = 6
+    hidden_feat = 12
+    out_feat = 7
     linear = nn.Linear(in_feat, in_feat).cuda()
-    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
-    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
-    moe_raw.weight.data = moe.weight.data.clone()
+    moe = MOELayer(num_expert, in_feat, hidden_feat, out_feat).cuda()
+    moe_raw = MOELayer_raw(num_expert, in_feat, hidden_feat, out_feat).cuda()
+    moe_raw.weight1.data = moe.weight1.data.clone()
+    moe_raw.weight2.data = moe.weight2.data.clone()
     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0, high=num_expert, size=(batch_size,), requires_grad=False).int().cuda()
@@ -104,6 +122,7 @@ def test():
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

     names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
+    names = ['Out']
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
         print('{} abs err {}'.format(name, err))
pytorch/cuda/moe_cuda_kernel.cu (view file @ 35addec6)

@@ -10,17 +10,20 @@
 #include <cublas_v2.h>
 #include <helper_cuda.h>

-// #include "timer.hh"
+#include "timer.hh"
 #include "cublas_wrapper.h"
 #include "cuda_stream_manager.h"

 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

 // #define MOE_BREAKDOWN
+#define MOE_DEBUG

 template <typename scalar_t>
 __global__
 void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
         const int* offset, const scalar_t** ptrs) {
     size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
     if (idx < n) {
         ptrs[idx] = base + stride * offset[idx];
@@ -32,22 +35,35 @@ template <typename scalar_t>
 void moe_cuda_forward_impl(
         const scalar_t* input,
         const int* d_gate,
-        const scalar_t* weight,
+        const scalar_t* weight1,
+        const scalar_t* weight2,
         scalar_t* output,
         const size_t batch_size,
         const size_t in_feat,
+        const size_t hidden_feat,
         const size_t out_feat,
-        const size_t num_expert,
-        cublasOperation_t transb) {
+        const size_t num_expert) {
     auto h = getCudaStreamManager(num_expert);

+#ifdef MOE_BREAKDOWN
+    timestamp(t_init);
+#endif

-    scalar_t *input_buf, *output_buf;
+    scalar_t *input_buf, *hidden_buf, *output_buf;
     checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size * in_feat));
     checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size * out_feat));
+    checkCudaErrors(cudaMalloc(&hidden_buf, sizeof(scalar_t) * batch_size * hidden_feat));

+#ifdef MOE_BREAKDOWN
+    timestamp(t_malloc);
+    fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_init, t_malloc) * 1e6);
+#endif

     int* gate = new int[batch_size];
     int* expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
@@ -55,6 +71,13 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size, cudaMemcpyDeviceToHost));

+#ifdef MOE_BREAKDOWN
+    timestamp(t_cpy);
+    fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_malloc, t_cpy) * 1e6);
+#endif

     for (int i = 0; i < batch_size; ++i) {
         ++expert_count[gate[i]];
     }

@@ -62,6 +85,13 @@ void moe_cuda_forward_impl(
     for (int i = 1; i < num_expert; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
     }

+#ifdef MOE_BREAKDOWN
+    timestamp(t_expert);
+    fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) * 1e6);
+#endif

     for (int i = 0; i < batch_size; ++i) {
         int target_idx = expert_ptr[gate[i]]++;
 #ifdef MOE_DEBUG_SCATTER

@@ -73,6 +103,13 @@ void moe_cuda_forward_impl(
                 h->getStream(gate[i])));
     }

+#ifdef MOE_BREAKDOWN
+    h->sync();
+    timestamp(t_scatter);
+    fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) * 1e6);
+#endif

     scalar_t alpha = 1, beta = 0;
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
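The host-side bookkeeping above is a counting sort: count the tokens per expert, take an exclusive prefix sum into expert_ptr, then give each token the next free slot of its expert so that input_buf holds each expert's rows contiguously (the later reverse loop undoes the permutation for the gather). A small Python sketch of the same index computation (the function name is illustrative, not from the repository):

    def scatter_indices(gate, num_expert):
        # gate: host copy of d_gate, one expert id per token.
        expert_count = [0] * num_expert
        for g in gate:
            expert_count[g] += 1

        expert_ptr = [0] * num_expert              # exclusive prefix sum of the counts
        for e in range(1, num_expert):
            expert_ptr[e] = expert_ptr[e - 1] + expert_count[e - 1]

        target_idx = [0] * len(gate)               # destination row of each token in input_buf
        for i, g in enumerate(gate):
            target_idx[i] = expert_ptr[g]
            expert_ptr[g] += 1
        return expert_count, target_idx

    # e.g. scatter_indices([1, 0, 1, 0], 2) == ([2, 2], [2, 0, 3, 1])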
@@ -86,19 +123,37 @@ void moe_cuda_forward_impl(
 #endif
         // Use T(B) x T(A) = T(C) to produce row-major C
         checkCudaErrors(cublasXgemm(h->getHandle(i),
-                (transb == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T,
+                CUBLAS_OP_T,
                 CUBLAS_OP_N,
-                out_feat, expert_count[i], in_feat,
+                hidden_feat, expert_count[i], in_feat,
                 &alpha,
-                weight + i * in_feat * out_feat,
-                (transb == CUBLAS_OP_T) ? out_feat : in_feat,
+                weight1 + i * in_feat * hidden_feat, in_feat,
                 input_buf + ptr * in_feat, in_feat,
                 &beta,
-                output_buf + out_feat * ptr, out_feat
+                hidden_buf + hidden_feat * ptr, hidden_feat
                 ));
+        checkCudaErrors(cublasXgemm(h->getHandle(i),
+                CUBLAS_OP_T,
+                CUBLAS_OP_N,
+                out_feat, expert_count[i], hidden_feat,
+                &alpha,
+                weight2 + i * hidden_feat * out_feat, hidden_feat,
+                hidden_buf + hidden_feat * ptr, hidden_feat,
+                &beta,
+                output_buf + out_feat * ptr, out_feat
+                ));
         ptr += expert_count[i];
     }

+#ifdef MOE_BREAKDOWN
+    h->sync();
+    timestamp(t_mm);
+    fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) * 1e6);
+#endif

     for (int i = batch_size - 1; i >= 0; --i) {
         int target_idx = --expert_ptr[gate[i]];
 #ifdef MOE_DEBUG_SCATTER
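The "T(B) x T(A) = T(C)" comment is the usual trick for getting row-major results out of column-major cuBLAS: a row-major buffer read column-major is its transpose, so requesting the product of op(weight1) and the input block with (CUBLAS_OP_T, CUBLAS_OP_N) yields a buffer that, read back row-major, equals input @ weight1.T; the second GEMM plays the same game with weight2. A NumPy sketch of the layout argument for the first GEMM (names and shapes are illustrative, not from the repository):

    import numpy as np

    n_tok, in_feat, hidden_feat = 4, 6, 12
    X  = np.random.rand(n_tok, in_feat).astype(np.float32)        # row-major slice of input_buf
    W1 = np.random.rand(hidden_feat, in_feat).astype(np.float32)  # row-major weight1[i]

    # Column-major views of the row-major buffers (what cuBLAS "sees"):
    A_cm = W1.T           # lda = in_feat  -> [in_feat, hidden_feat]
    B_cm = X.T            # ldb = in_feat  -> [in_feat, n_tok]

    # gemm(OP_T, OP_N, m=hidden_feat, n=n_tok, k=in_feat):
    C_cm = A_cm.T @ B_cm  # [hidden_feat, n_tok], column-major with ldc = hidden_feat

    # Reading that buffer back row-major gives exactly the hidden activations X @ W1.T:
    assert np.allclose(C_cm.T, X @ W1.T, atol=1e-5)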
@@ -113,6 +168,14 @@ void moe_cuda_forward_impl(
     h->sync();

+#ifdef MOE_BREAKDOWN
+    timestamp(t_gather);
+    fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) * 1e6);
+    fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) * 1e6);
+#endif

     cudaFree(input_buf);
     cudaFree(output_buf);
 }
@@ -159,14 +222,17 @@ void moe_cuda_grad_weight(
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input,
         torch::Tensor gate,
-        torch::Tensor weight) {
+        torch::Tensor weight1,
+        torch::Tensor weight2) {
     const auto batch_size = input.size(0);
-    const auto num_expert = weight.size(0);
-    const auto out_feat = weight.size(1);
-    const auto in_feat = weight.size(2);
+    const auto num_expert = weight1.size(0);
+    const auto out_feat = weight2.size(1);
+    const auto hidden_feat = weight1.size(1);
+    const auto in_feat = weight1.size(2);
 #ifdef MOE_DEBUG
-    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
-            batch_size, num_expert, in_feat, out_feat);
+    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, hidden_feat = %ld, out_feat (d_ffn)=%ld\n",
+            batch_size, num_expert, in_feat, hidden_feat, out_feat);
 #endif
     auto output = input.new_zeros({batch_size, out_feat});
@@ -174,13 +240,14 @@ std::vector<torch::Tensor> moe_cuda_forward(
         moe_cuda_forward_impl<scalar_t>(
             input.data_ptr<scalar_t>(),
             gate.data_ptr<int>(),
-            weight.data_ptr<scalar_t>(),
+            weight1.data_ptr<scalar_t>(),
+            weight2.data_ptr<scalar_t>(),
             output.data_ptr<scalar_t>(),
             batch_size,
             in_feat,
+            hidden_feat,
             out_feat,
-            num_expert,
-            CUBLAS_OP_T
+            num_expert
         );
     }));
@@ -205,6 +272,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
     auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat});  // num_expert x out_feat x in_feat

     // grad_input is easy to compute, exactly the same as forward
+    /* TODO: Backward currently brokenn
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
         moe_cuda_forward_impl<scalar_t>(
             grad_output.data_ptr<scalar_t>(),

@@ -218,6 +286,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
             CUBLAS_OP_N
         );
     }));
+    */

     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
         moe_cuda_grad_weight<scalar_t>(