Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
269f3fd4
Commit
269f3fd4
authored
Feb 20, 2021
by
Jiezhong Qiu
Browse files
fix pylint issues
parent
1a6073b5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
8 deletions
+10
-8
fmoe/distributed.py
fmoe/distributed.py
+1
-1
fmoe/layers.py
fmoe/layers.py
+7
-6
fmoe/transformer.py
fmoe/transformer.py
+2
-1
No files found.
fmoe/distributed.py
View file @
269f3fd4
...
...
@@ -103,7 +103,7 @@ class DistributedGroupedDataParallel(nn.Module):
synced
=
_unflatten_dense_tensors
(
coalesced
,
datas
)
for
d
,
s
in
zip
(
datas
,
synced
):
d
.
copy_
(
s
)
def
forward
(
self
,
*
args
,
**
kwargs
):
r
'''
Directly call the module's forward function.
...
...
fmoe/layers.py
View file @
269f3fd4
r
'''
Layers that FMoE provides to users
'''
import
math
import
torch
import
torch.nn
as
nn
import
numpy
as
np
import
math
from
.functions
import
moe_prepare_forward
from
.functions
import
MOEScatter
,
MOEGather
,
MOELinear
...
...
@@ -34,17 +34,18 @@ class FMoELinear(nn.Module):
'''
rng
=
np
.
random
.
default_rng
(
np
.
random
.
randint
(
2048
)
+
self
.
rank
)
# copied from
https://pytorch.org/docs/stable/nn.init.html#
torch.nn.init.kaiming_uniform_
# copied from torch.nn.init.kaiming_uniform_
fan
=
nn
.
init
.
_calculate_correct_fan
(
self
.
weight
[
0
],
'fan_in'
)
gain
=
nn
.
init
.
calculate_gain
(
'leaky_relu'
,
math
.
sqrt
(
5
))
std
=
gain
/
math
.
sqrt
(
fan
)
bound
=
math
.
sqrt
(
3.0
)
*
std
# Calculate uniform bounds from standard deviation
bound
=
math
.
sqrt
(
3.0
)
*
std
device
=
self
.
weight
.
device
dtype
=
self
.
weight
.
dtype
for
i
in
range
(
self
.
num_expert
):
weight
=
rng
.
uniform
(
-
bound
,
bound
,
size
=
tuple
(
self
.
weight
[
i
].
size
()))
self
.
weight
.
data
[
i
]
=
torch
.
tensor
(
weight
,
dtype
=
dtype
,
device
=
device
)
weight
=
rng
.
uniform
(
-
bound
,
bound
,
size
=
tuple
(
self
.
weight
[
i
].
size
()))
self
.
weight
.
data
[
i
]
=
torch
.
tensor
(
weight
,
dtype
=
dtype
,
device
=
device
)
def
forward
(
self
,
inp
,
fwd_expert_count
):
r
'''
...
...
fmoe/transformer.py
View file @
269f3fd4
...
...
@@ -52,7 +52,8 @@ class FMoETransformerMLP(FMoE):
super
().
__init__
(
num_expert
=
num_expert
,
d_model
=
d_model
,
gate
=
gate
,
top_k
=
top_k
,
world_size
=
world_size
,
mp_group
=
mp_group
,
expert_fn
=
expert_fn
)
self
.
experts
=
_Expert
(
num_expert
,
d_model
,
d_hidden
,
activation
,
self
.
mp_rank
)
self
.
experts
=
_Expert
(
num_expert
,
d_model
,
d_hidden
,
activation
,
rank
=
self
.
mp_rank
)
self
.
pre_lnorm
=
pre_lnorm
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
mark_parallel_comm
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment