OpenDAS / FastMoE

Commit 94eca783 authored Feb 24, 2021 by Rick Ho

    reset parameters in megatron

parent f6afdbee

Showing 2 changed files with 34 additions and 25 deletions (+34 -25):

    fmoe/layers.py    +0  -25
    fmoe/megatron.py  +34 -0
fmoe/layers.py
r'''
Layers that FMoE provides to users
'''
import math
import torch
import torch.nn as nn
import numpy as np
from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
...
@@ -31,29 +29,6 @@ class FMoELinear(nn.Module):
            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        r'''
        Initialize the weight as linear layers
        '''
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        # copied from torch.nn.init.kaiming_uniform_
        fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
        gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
        std = gain / math.sqrt(fan)
        bound = math.sqrt(3.0) * std
        device = self.weight.device
        dtype = self.weight.dtype
        weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
            bound = 1 / math.sqrt(fan_in)
            bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
            self.bias.data = torch.tensor(bias, dtype=dtype, device=device)

    def forward(self, inp, fwd_expert_count):
        r'''
...
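For reference, the bound computed in reset_parameters is the same one torch.nn.init.kaiming_uniform_ derives for a=math.sqrt(5), i.e. the default nn.Linear initialization; only the sampling is delegated to a NumPy generator. A minimal sketch checking that equivalence (not part of this commit; the tensor shape below is made up):

    import math
    import torch
    import torch.nn as nn

    w = torch.empty(4, 16)  # hypothetical (out_feat, in_feat) slice, like self.weight[0]
    fan = nn.init._calculate_correct_fan(w, 'fan_in')
    gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
    bound = math.sqrt(3.0) * gain / math.sqrt(fan)   # same bound as reset_parameters above

    # kaiming_uniform_ with a=sqrt(5) samples from the same [-bound, bound] interval
    nn.init.kaiming_uniform_(w, a=math.sqrt(5))
    assert w.abs().max().item() <= bound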
fmoe/megatron.py
...
@@ -6,6 +6,8 @@ See `examples/megatron` for usage instructions.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
...
@@ -24,6 +26,26 @@ class _MegatronMLP(nn.Module):
        return x, torch.zeros_like(x)


def _random_init_weight(self, rng):
    r'''
    Copied from torch.nn.init.kaiming_uniform_
    '''
    fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
    gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
    if self.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in)
        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
        self.bias.data = torch.tensor(bias, dtype=dtype, device=device)


class MegatronMLP(FMoETransformerMLP):
    r'''
    Make the FMoETransformerMLP layer that distributes experts across
...
@@ -43,6 +65,18 @@ class MegatronMLP(FMoETransformerMLP):
            world_size=world_size,
            mp_group=group,
            expert_dp_comm='none' if args.distributed_experts else 'dp')
        self.hidden_size = args.hidden_size
        self.rank = args.rank
        self.reset_parameters()

    def reset_parameters(self):
        r'''
        Initialize the weight as linear layers.
        As megatron is using fixed random seed for some nasty stuff, an
        additional numpy rng is used.
        '''
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _random_init_weight(self.experts.htoh4, rng)
        _random_init_weight(self.experts.h4toh, rng)

    def forward(self, inp):
        return super().forward(inp), torch.zeros(self.hidden_size,
...
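The reset_parameters docstring states the motivation for this commit: Megatron fixes the random seed, so without the + self.rank offset every model-parallel rank would draw the same np.random.randint(2048) and initialize its experts identically. A hedged sketch of that effect (the seed 1234, the helper name, and the shapes below are illustrative, not taken from the commit):

    import numpy as np

    def make_expert_rng(rank, megatron_seed=1234):
        # stand-in for Megatron's fixed seeding followed by the commit's rank offset
        np.random.seed(megatron_seed)
        base = np.random.randint(2048)        # identical on every rank
        return np.random.default_rng(base + rank)

    w0 = make_expert_rng(rank=0).uniform(-1, 1, size=(4, 16))
    w1 = make_expert_rng(rank=1).uniform(-1, 1, size=(4, 16))
    assert not np.allclose(w0, w1)            # experts on different ranks stay distinct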