Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
881b10c2
Commit
881b10c2
authored
Jan 01, 2021
by
Rick Ho
Browse files
make it run
parent
691e92e1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+3
-3
No files found.
pytorch/cuda/moe_cuda_kernel.cu
View file @
881b10c2
...
...
@@ -20,9 +20,9 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#define MOE_DEBUG
//
#define MOE_DEBUG
// #define MOE_BREAKDOWN
#define MOE_DEBUG_SCATTER
//
#define MOE_DEBUG_SCATTER
template
<
typename
scalar_t
>
__global__
...
...
@@ -114,7 +114,7 @@ void moe_cuda_forward_impl(
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
pos
[
i
]
=
expert_ptr
[
gate
[
i
]]
++
;
}
for
(
int
i
=
batch_size
-
1
;
i
>
0
;
--
i
)
{
for
(
int
i
=
tot_expert
-
1
;
i
>
0
;
--
i
)
{
expert_ptr
[
i
]
=
expert_ptr
[
i
-
1
];
}
expert_ptr
[
0
]
=
0
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment