Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
93291a7e
Commit
93291a7e
authored
Dec 29, 2020
by
Jiezhong Qiu
Browse files
update
parent
c5d719cf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
3 deletions
+16
-3
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+16
-3
No files found.
pytorch/cuda/moe_cuda_kernel.cu
View file @
93291a7e
...
...
@@ -197,8 +197,7 @@ void moe_cuda_grad_weight(
const
size_t
batch_size
,
const
size_t
in_feat
,
const
size_t
out_feat
,
const
size_t
num_expert
,
cublasOperation_t
transb
)
{
const
size_t
num_expert
)
{
Helper
*
h
=
getHelper
(
num_expert
);
...
...
@@ -207,7 +206,7 @@ void moe_cuda_grad_weight(
checkCudaErrors
(
cudaMemcpy
(
gate_host
,
gate
,
batch_size
*
sizeof
(
int
),
cudaMemcpyDeviceToHost
));
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
checkCudaErrors
(
cublasSetStream
(
h
->
handle
,
*
(
h
->
streams
+
gate_host
[
i
])));
checkCudaErrors
(
cublas
S
gemm
(
h
->
handle
,
checkCudaErrors
(
cublas
X
gemm
(
h
->
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
out_feat
,
...
...
@@ -283,6 +282,20 @@ std::vector<torch::Tensor> moe_cuda_backward(
CUBLAS_OP_N
);
}));
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_cuda_backward"
,
([
&
]
{
moe_cuda_grad_weight
<
scalar_t
>
(
input
.
data_ptr
<
scalar_t
>
(),
gate
.
data_ptr
<
int
>
(),
grad_output
.
data_ptr
<
scalar_t
>
(),
grad_weight
.
data_ptr
<
scalar_t
>
(),
batch_size
,
out_feat
,
in_feat
,
num_expert
);
}));
return
{
grad_input
,
grad_weight
};
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment