Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
92f1774a
"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "42067ef2628320aa28cc79eb7d8bca97088f934e"
Commit
92f1774a
authored
Jan 09, 2021
by
Rick Ho
Browse files
moe backward (cannot pass test)
parent
c91dfad8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
84 additions
and
93 deletions
+84
-93
pytorch/cuda/moe.cpp
pytorch/cuda/moe.cpp
+12
-13
pytorch/cuda/moe.py
pytorch/cuda/moe.py
+9
-12
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+63
-68
No files found.
pytorch/cuda/moe.cpp
View file @
92f1774a
...
@@ -22,10 +22,10 @@ std::vector<torch::Tensor> moe_cuda_forward(
...
@@ -22,10 +22,10 @@ std::vector<torch::Tensor> moe_cuda_forward(
torch
::
Tensor
expert_count
);
torch
::
Tensor
expert_count
);
std
::
vector
<
torch
::
Tensor
>
moe_cuda_backward
(
std
::
vector
<
torch
::
Tensor
>
moe_cuda_backward
(
torch
::
Tensor
grad_output
,
torch
::
Tensor
grad_output
_buf
,
torch
::
Tensor
input
,
torch
::
Tensor
input
_buf
,
torch
::
Tensor
gate
,
torch
::
Tensor
weight
,
torch
::
Tensor
weigh
t
);
torch
::
Tensor
expert_coun
t
);
// C++ interface
// C++ interface
...
@@ -58,7 +58,7 @@ std::vector<torch::Tensor> moe_local_gather(
...
@@ -58,7 +58,7 @@ std::vector<torch::Tensor> moe_local_gather(
std
::
vector
<
torch
::
Tensor
>
moe_forward
(
std
::
vector
<
torch
::
Tensor
>
moe_forward
(
torch
::
Tensor
input_buf
,
// [batch_size x in_feat]
torch
::
Tensor
input_buf
,
// [batch_size x in_feat]
torch
::
Tensor
weight
,
// [num_expert x
hidden
_feat x in_feat]
torch
::
Tensor
weight
,
// [num_expert x
out
_feat x in_feat]
torch
::
Tensor
expert_count
// [batch_size]
torch
::
Tensor
expert_count
// [batch_size]
)
{
)
{
CHECK_INPUT
(
input_buf
);
CHECK_INPUT
(
input_buf
);
...
@@ -72,21 +72,20 @@ std::vector<torch::Tensor> moe_forward(
...
@@ -72,21 +72,20 @@ std::vector<torch::Tensor> moe_forward(
}
}
std
::
vector
<
torch
::
Tensor
>
moe_backward
(
std
::
vector
<
torch
::
Tensor
>
moe_backward
(
torch
::
Tensor
grad_output
,
// [batch_size x out_feat]
torch
::
Tensor
grad_output
_buf
,
// [batch_size x out_feat]
torch
::
Tensor
input
,
// [batch_size x out_feat]
torch
::
Tensor
input
_buf
,
// [batch_size x out_feat]
torch
::
Tensor
gate
,
// [batch_size
]
torch
::
Tensor
weight
,
// [num_expert x out_feat x in_feat
]
torch
::
Tensor
weight
// [num_expert x out_feat x in_feat]
torch
::
Tensor
expert_count
)
{
)
{
CHECK_INPUT
(
grad_output
);
CHECK_INPUT
(
grad_output_buf
);
CHECK_INPUT
(
input
);
CHECK_INPUT
(
input_buf
);
CHECK_INPUT
(
gate
);
CHECK_INPUT
(
weight
);
CHECK_INPUT
(
weight
);
/*
/*
The bias term should have been merged into weight. Note the following fact that
The bias term should have been merged into weight. Note the following fact that
Wx+b = [W b] [x]
Wx+b = [W b] [x]
[1]
[1]
*/
*/
return
moe_cuda_backward
(
grad_output
,
input
,
gate
,
weight
);
return
moe_cuda_backward
(
grad_output
_buf
,
input
_buf
,
weight
,
expert_count
);
}
}
...
...
pytorch/cuda/moe.py
View file @
92f1774a
...
@@ -16,21 +16,21 @@ class MOEFunction(Function):
...
@@ -16,21 +16,21 @@ class MOEFunction(Function):
output_buf
,
=
moe_cuda
.
forward
(
input_buf
,
weight
,
expert_count
)
output_buf
,
=
moe_cuda
.
forward
(
input_buf
,
weight
,
expert_count
)
output
=
moe_cuda
.
local_gather
(
output_buf
,
pos
)
output
=
moe_cuda
.
local_gather
(
output_buf
,
pos
)
variables
=
[
inp
,
gate
,
weight
,
expert_count
,
pos
]
variables
=
[
inp
ut_buf
,
gate
,
weight
,
expert_count
,
pos
]
ctx
.
save_for_backward
(
*
variables
)
ctx
.
save_for_backward
(
*
variables
)
return
output
[
0
]
return
output
[
0
]
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_out
):
def
backward
(
ctx
,
grad_out
):
# print("grad_out", grad_out)
input_buf
,
gate
,
weight
,
expert_count
,
pos
=
ctx
.
saved_tensors
# print("input", ctx.saved_tensors[0])
grad_
inp
,
grad_weight
=
moe_cuda
.
backward
(
grad_
out_buf
,
=
moe_cuda
.
local_scatter
(
grad_out
.
contiguous
(),
pos
)
grad_
out
.
contiguous
(),
*
ctx
.
saved_tensors
)
grad_
inp_buf
,
grad_weight
=
moe_cuda
.
backward
(
out_feat
,
in_feat
=
grad_weight
.
size
()[
1
:]
grad_out_buf
,
input_buf
,
weight
,
expert_count
)
# print("grad_weight_column_major", grad_weight.flatten()
)
grad_inp
,
=
moe_cuda
.
local_gather
(
grad_inp_buf
,
pos
)
grad_weight_row_major
=
grad_weight
.
view
(
-
1
,
in_feat
,
out_feat
).
transpose
(
-
1
,
-
2
).
contiguous
().
view
(
-
1
,
out_feat
,
in_feat
)
return
grad_inp
,
None
,
grad_weight
_row_major
return
grad_inp
,
None
,
grad_weight
class
MOELayer
(
nn
.
Module
):
class
MOELayer
(
nn
.
Module
):
...
@@ -82,9 +82,6 @@ def test_module(moe, linear, inp, gate):
...
@@ -82,9 +82,6 @@ def test_module(moe, linear, inp, gate):
moe
.
zero_grad
()
moe
.
zero_grad
()
x
=
linear
(
inp
)
x
=
linear
(
inp
)
output
=
moe
(
x
,
gate
)
output
=
moe
(
x
,
gate
)
print
(
output
)
return
output
print
(
output
)
y
=
output
.
mean
()
y
=
output
.
mean
()
y
.
backward
()
y
.
backward
()
return
output
,
moe
.
weight
.
grad
,
linear
.
weight
.
grad
,
linear
.
bias
.
grad
return
output
,
moe
.
weight
.
grad
,
linear
.
weight
.
grad
,
linear
.
bias
.
grad
...
...
pytorch/cuda/moe_cuda_kernel.cu
View file @
92f1774a
...
@@ -124,9 +124,7 @@ void moe_cuda_forward_impl(
...
@@ -124,9 +124,7 @@ void moe_cuda_forward_impl(
scalar_t
*
output_buf
,
scalar_t
*
output_buf
,
const
size_t
in_feat
,
const
size_t
in_feat
,
const
size_t
out_feat
,
const
size_t
out_feat
,
const
size_t
num_expert
,
const
size_t
num_expert
)
{
cublasOperation_t
transb
)
{
scalar_t
alpha
=
1
,
beta
=
0
;
scalar_t
alpha
=
1
,
beta
=
0
;
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
...
@@ -151,40 +149,55 @@ void moe_cuda_forward_impl(
...
@@ -151,40 +149,55 @@ void moe_cuda_forward_impl(
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
moe_cuda_grad_weight
(
void
moe_cuda_backward_impl
(
const
scalar_t
*
input
,
const
scalar_t
*
grad_output_buf
,
const
int
*
gate
,
const
scalar_t
*
input_buf
,
const
scalar_t
*
grad_output
,
const
scalar_t
*
weight
,
scalar_t
*
grad_weight
,
// [num_expert x out_feat x in_feat]
const
int
*
expert_count
,
scalar_t
*
grad_input_buf
,
scalar_t
*
grad_weight
,
const
size_t
batch_size
,
const
size_t
batch_size
,
const
size_t
in_feat
,
const
size_t
in_feat
,
const
size_t
out_feat
,
const
size_t
out_feat
,
const
size_t
num_expert
)
{
const
size_t
num_expert
)
{
ENSURE_SMGR
(
smgr
,
num_expert
);
scalar_t
alpha
=
1
,
beta
=
0
;
int
*
gate_host
=
new
int
[
batch_size
];
for
(
int
i
=
0
,
ptr
=
0
;
i
<
num_expert
;
++
i
)
{
scalar_t
alpha
=
1
,
beta
=
1
;
if
(
expert_count
[
i
]
==
0
)
{
checkCudaErrors
(
cudaMemcpy
(
gate_host
,
gate
,
batch_size
*
sizeof
(
int
),
cudaMemcpyDeviceToHost
));
cudaMemset
(
grad_weight
+
i
*
in_feat
*
out_feat
,
0
,
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
sizeof
(
scalar_t
)
*
in_feat
*
out_feat
);
// checkCudaErrors(cublasSetStream);
continue
;
checkCudaErrors
(
cublasXgemm
(
smgr
.
handles
[
0
],
}
CUBLAS_OP_N
,
// Use T(B) x T(A) = T(C) to produce row-major C
CUBLAS_OP_T
,
out_feat
,
// Backward input: g_i = w @ g_o
in_feat
,
checkCudaErrors
(
cublasXgemm
(
smgr
.
handles
[
i
],
1
,
CUBLAS_OP_N
,
&
alpha
,
CUBLAS_OP_N
,
grad_output
+
i
*
out_feat
,
in_feat
,
expert_count
[
i
],
out_feat
,
out_feat
,
&
alpha
,
input
+
i
*
in_feat
,
weight
+
i
*
in_feat
*
out_feat
,
in_feat
,
in_feat
,
grad_output_buf
+
ptr
*
out_feat
,
out_feat
,
&
beta
,
&
beta
,
grad_weight
+
gate_host
[
i
]
*
out_feat
*
in_feat
,
grad_input_buf
+
in_feat
*
ptr
,
in_feat
out_feat
));
));
}
for
(
size_t
i
=
0
;
i
<
num_expert
;
++
i
)
{
// Backward weight: g_w = i @ g_o
checkCudaErrors
(
cudaStreamSynchronize
(
*
(
smgr
.
streams
+
i
)));
checkCudaErrors
(
cublasXgemm
(
smgr
.
handles
[
i
],
}
CUBLAS_OP_N
,
delete
[]
gate_host
;
CUBLAS_OP_T
,
in_feat
,
out_feat
,
expert_count
[
i
],
&
alpha
,
input_buf
+
in_feat
*
ptr
,
in_feat
,
grad_output_buf
+
ptr
*
out_feat
,
out_feat
,
&
beta
,
grad_weight
+
i
*
in_feat
*
out_feat
,
in_feat
));
ptr
+=
expert_count
[
i
];
}
smgr
.
sync
();
}
}
...
@@ -285,8 +298,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
...
@@ -285,8 +298,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
output
.
data_ptr
<
scalar_t
>
(),
output
.
data_ptr
<
scalar_t
>
(),
in_feat
,
in_feat
,
out_feat
,
out_feat
,
num_expert
,
num_expert
CUBLAS_OP_T
);
);
}));
}));
...
@@ -294,49 +306,32 @@ std::vector<torch::Tensor> moe_cuda_forward(
...
@@ -294,49 +306,32 @@ std::vector<torch::Tensor> moe_cuda_forward(
}
}
std
::
vector
<
torch
::
Tensor
>
moe_cuda_backward
(
std
::
vector
<
torch
::
Tensor
>
moe_cuda_backward
(
torch
::
Tensor
grad_output
,
// [batch_size x out_feat]
torch
::
Tensor
grad_output
_buf
,
// [batch_size x out_feat]
torch
::
Tensor
input
,
// [batch_size x out_feat]
torch
::
Tensor
input
_buf
,
// [batch_size x out_feat]
torch
::
Tensor
gate
,
// [batch_size
]
torch
::
Tensor
weight
,
// [num_expert x out_feat x in_feat
]
torch
::
Tensor
weight
// [num_expert x out_feat x in_feat]
torch
::
Tensor
expert_count
)
{
)
{
const
auto
batch_size
=
input
.
size
(
0
);
const
auto
batch_size
=
input
_buf
.
size
(
0
);
const
auto
num_expert
=
weight
.
size
(
0
);
const
auto
num_expert
=
weight
.
size
(
0
);
const
auto
out_feat
=
weight
.
size
(
1
);
const
auto
out_feat
=
weight
.
size
(
1
);
const
auto
in_feat
=
weight
.
size
(
2
);
const
auto
in_feat
=
weight
.
size
(
2
);
#ifdef MOE_DEBUG
#ifdef MOE_DEBUG
printf
(
"[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld
\n
"
,
batch_size
,
num_expert
,
in_feat
,
out_feat
);
printf
(
"[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
"out_feat (d_ffn)=%ld
\n
"
,
batch_size
,
num_expert
,
in_feat
,
out_feat
);
#endif
#endif
const
int
device
=
device_of
(
input
).
value
().
index
();
if
(
smgr
.
streams
==
NULL
)
{
smgr
.
setup
(
num_expert
,
device
);
}
auto
grad_input
=
grad_output
.
new_
zeros
({
batch_size
,
in_feat
});
// batch_size x in_feat
auto
grad_input
_buf
=
grad_output
_buf
.
new_
empty
({
batch_size
,
in_feat
});
auto
grad_weight
=
grad_output
.
new_
zeros
({
num_expert
,
out_feat
,
in_feat
});
// num_expert x out_feat x in_feat
auto
grad_weight
=
grad_output
_buf
.
new_
empty
({
num_expert
,
out_feat
,
in_feat
});
// grad_input is easy to compute, exactly the same as forward
AT_DISPATCH_FLOATING_TYPES
(
input_buf
.
scalar_type
(),
"moe_cuda_backward"
,
([
&
]
{
/* TODO: Backward currently brokenn
moe_cuda_backward_impl
<
scalar_t
>
(
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
grad_output_buf
.
data_ptr
<
scalar_t
>
(),
moe_cuda_forward_impl<scalar_t>(
input_buf
.
data_ptr
<
scalar_t
>
(),
grad_output.data_ptr<scalar_t>(),
gate.data_ptr<int>(),
weight
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
grad_input.data_ptr<scalar_t>(),
expert_count
.
data_ptr
<
int
>
(),
batch_size,
grad_input_buf
.
data_ptr
<
scalar_t
>
(),
out_feat,
in_feat,
num_expert,
CUBLAS_OP_N
);
}));
*/
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_cuda_backward"
,
([
&
]
{
moe_cuda_grad_weight
<
scalar_t
>
(
input
.
data_ptr
<
scalar_t
>
(),
gate
.
data_ptr
<
int
>
(),
grad_output
.
data_ptr
<
scalar_t
>
(),
grad_weight
.
data_ptr
<
scalar_t
>
(),
grad_weight
.
data_ptr
<
scalar_t
>
(),
batch_size
,
batch_size
,
in_feat
,
in_feat
,
...
@@ -345,7 +340,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
...
@@ -345,7 +340,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
);
);
}));
}));
return
{
grad_input
,
grad_weight
};
return
{
grad_input
_buf
,
grad_weight
};
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment