Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
8bac18dc
Commit
8bac18dc
authored
Mar 23, 2021
by
TiagoMAntunes
Browse files
Updated arguments for MOELinear.apply
parent
a3b2eb62
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
5 deletions
+5
-5
fmoe/functions.py
fmoe/functions.py
+4
-4
fmoe/layers.py
fmoe/layers.py
+1
-1
No files found.
fmoe/functions.py
View file @
8bac18dc
...
@@ -110,21 +110,21 @@ class MOELinear(Function):
...
@@ -110,21 +110,21 @@ class MOELinear(Function):
"""
"""
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
global_input_buf
,
weight
,
fwd_expert_count
):
def
forward
(
ctx
,
global_input_buf
,
fwd_expert_count
,
weight
,
bias
=
None
):
(
global_output_buf
,)
=
fmoe_cuda
.
forward
(
(
global_output_buf
,)
=
fmoe_cuda
.
forward
(
global_input_buf
,
weight
,
fwd_expert_count
global_input_buf
,
weight
,
fwd_expert_count
)
)
variables
=
(
global_input_buf
,
weight
,
fwd_expert_count
)
variables
=
(
global_input_buf
,
fwd_expert_count
,
weight
)
ctx
.
save_for_backward
(
*
variables
)
ctx
.
save_for_backward
(
*
variables
)
return
global_output_buf
return
global_output_buf
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_out
):
def
backward
(
ctx
,
grad_out
):
(
input_buf
,
weight
,
fwd_expert_count
)
=
ctx
.
saved_tensors
(
input_buf
,
fwd_expert_count
,
weight
)
=
ctx
.
saved_tensors
grad_inp_buf
,
grad_weight
=
fmoe_cuda
.
backward
(
grad_inp_buf
,
grad_weight
=
fmoe_cuda
.
backward
(
grad_out
,
input_buf
,
weight
,
fwd_expert_count
grad_out
,
input_buf
,
weight
,
fwd_expert_count
)
)
return
grad_inp_buf
,
grad_weight
,
None
return
grad_inp_buf
,
None
,
grad_weight
class
MOEGather
(
Function
):
class
MOEGather
(
Function
):
...
...
fmoe/layers.py
View file @
8bac18dc
...
@@ -41,7 +41,7 @@ class FMoELinear(nn.Module):
...
@@ -41,7 +41,7 @@ class FMoELinear(nn.Module):
r
"""
r
"""
Call MOE function
Call MOE function
"""
"""
x
=
MOELinear
.
apply
(
inp
,
self
.
weight
,
fwd_expert_count
)
x
=
MOELinear
.
apply
(
inp
,
fwd_expert_count
,
self
.
weight
)
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
# TODO: torch.repeat_interleave seems have numerical
# TODO: torch.repeat_interleave seems have numerical
# instability in backward, leading to incorrect
# instability in backward, leading to incorrect
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment