Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
3c24222c
"vscode:/vscode.git/clone" did not exist on "9e125514e39a1f4281f822e6bcbf235310dc89b6"
Commit
3c24222c
authored
Feb 21, 2021
by
Jiezhong Qiu
Browse files
add and initialize bias term in FMoELinear
parent
406955e7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
7 deletions
+19
-7
fmoe/layers.py
fmoe/layers.py
+19
-7
No files found.
fmoe/layers.py
View file @
3c24222c
...
...
@@ -19,13 +19,18 @@ class FMoELinear(nn.Module):
performed in parallel to increase the performance.
The FMoELinear module provides such function.
'''
def
__init__
(
self
,
num_expert
=
32
,
in_feat
=
1024
,
out_feat
=
1024
,
rank
=
0
):
def
__init__
(
self
,
num_expert
:
int
,
in_feat
:
int
,
out_feat
:
int
,
bias
:
bool
=
True
,
rank
:
int
=
0
):
super
().
__init__
()
self
.
num_expert
=
num_expert
self
.
in_feat
=
in_feat
self
.
out_feat
=
out_feat
self
.
rank
=
rank
self
.
weight
=
nn
.
Parameter
(
torch
.
Tensor
(
num_expert
,
out_feat
,
in_feat
))
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
Tensor
(
num_expert
,
out_feat
))
else
:
self
.
register_parameter
(
'bias'
,
None
)
self
.
reset_parameters
()
def
reset_parameters
(
self
):
...
...
@@ -41,17 +46,24 @@ class FMoELinear(nn.Module):
bound
=
math
.
sqrt
(
3.0
)
*
std
device
=
self
.
weight
.
device
dtype
=
self
.
weight
.
dtype
for
i
in
range
(
self
.
num_expert
):
weight
=
rng
.
uniform
(
-
bound
,
bound
,
size
=
tuple
(
self
.
weight
[
i
].
size
()))
self
.
weight
.
data
[
i
]
=
torch
.
tensor
(
weight
,
dtype
=
dtype
,
device
=
device
)
weight
=
rng
.
uniform
(
-
bound
,
bound
,
size
=
tuple
(
self
.
weight
.
size
()))
self
.
weight
.
data
=
torch
.
tensor
(
weight
,
dtype
=
dtype
,
device
=
device
)
if
self
.
bias
is
not
None
:
fan_in
,
_
=
nn
.
init
.
_calculate_fan_in_and_fan_out
(
self
.
weight
[
0
])
bound
=
1
/
math
.
sqrt
(
fan_in
)
bias
=
rng
.
uniform
(
-
bound
,
bound
,
size
=
tuple
(
self
.
bias
.
size
()))
self
.
bias
.
data
=
torch
.
tensor
(
bias
,
dtype
=
dtype
,
device
=
device
)
def
forward
(
self
,
inp
,
fwd_expert_count
):
r
'''
Call MOE function
'''
return
MOELinear
.
apply
(
inp
,
self
.
weight
,
fwd_expert_count
)
x
=
MOELinear
.
apply
(
inp
,
self
.
weight
,
fwd_expert_count
)
if
self
.
bias
:
bias
=
torch
.
repeat_interleave
(
self
.
bias
,
fwd_expert_count
,
dim
=
0
)
x
=
x
+
bias
return
x
def
mark_module_parallel_comm
(
module
,
comm
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment