Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
da4a8e5e
Commit
da4a8e5e
authored
Feb 20, 2021
by
Jiezhong Qiu
Browse files
fix import error and dtype/device problem in MoE
parent
5f5ccd47
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
1 deletion
+8
-1
fmoe/layers.py
fmoe/layers.py
+8
-1
No files found.
fmoe/layers.py
View file @
da4a8e5e
...
...
@@ -3,6 +3,8 @@ Layers that FMoE provides to users
'''
import
torch
import
torch.nn
as
nn
import
numpy
as
np
import
math
from
.functions
import
moe_prepare_forward
from
.functions
import
MOEScatter
,
MOEGather
,
MOELinear
...
...
@@ -31,13 +33,18 @@ class FMoELinear(nn.Module):
Initialize the weight as linear layers
'''
rng
=
np
.
random
.
default_rng
(
np
.
random
.
randint
(
2048
)
+
self
.
rank
)
# copied from https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_
fan
=
nn
.
init
.
_calculate_correct_fan
(
self
.
weight
[
0
],
'fan_in'
)
gain
=
nn
.
init
.
calculate_gain
(
'leaky_relu'
,
math
.
sqrt
(
5
))
std
=
gain
/
math
.
sqrt
(
fan
)
bound
=
math
.
sqrt
(
3.0
)
*
std
# Calculate uniform bounds from standard deviation
device
=
self
.
weight
.
device
dtype
=
self
.
weight
.
dtype
for
i
in
range
(
self
.
num_expert
):
weight
=
rng
.
uniform
(
-
bound
,
bound
,
size
=
tuple
(
self
.
weight
[
i
].
size
()))
self
.
weight
.
data
[
i
]
=
torch
.
from_numpy
(
weight
)
self
.
weight
.
data
[
i
]
=
torch
.
tensor
(
weight
,
dtype
=
dtype
,
device
=
device
)
def
forward
(
self
,
inp
,
fwd_expert_count
):
r
'''
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment