OpenDAS / FastMoE
Commit ba878d29, authored Feb 26, 2021 by Rick Ho

fix lint

Parent: 66f7166d
Showing 4 changed files with 19 additions and 9 deletions.
fmoe/distributed.py  +1 -3
fmoe/gates.py        +5 -3
fmoe/layers.py       +4 -0
fmoe/megatron.py     +9 -3
fmoe/distributed.py (view file @ ba878d29)
@@ -90,14 +90,12 @@ class DistributedGroupedDataParallel(nn.Module):
                 groups[group_key] = [p]
             else:
                 groups[group_key].append(p)
-        for (dp_comm, dtype), group in groups.items():
+        for (dp_comm, _), group in groups.items():
             if dp_comm not in self.comms:
                 continue
             comm = self.comms[dp_comm]
             datas = [p.data for p in group]
             coalesced = _flatten_dense_tensors(datas)
-            if fp32_allreduce and dtype != torch.float32:
-                coalesced = coalesced.float()
             torch.distributed.broadcast(coalesced, 0, group=comm)
             torch.cuda.synchronize()
             synced = _unflatten_dense_tensors(coalesced, datas)
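Note on this hunk: parameters are grouped per communicator, flattened into one contiguous buffer, broadcast from rank 0, then unflattened and written back. A minimal sketch of that plumbing follows; the broadcast is commented out so it runs without an initialized process group, and the copy-back loop (which falls outside the extracted hunk) is shown here as an assumption about the usual pattern.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Stand-ins for one group of parameters collected under a single communicator.
params = [torch.randn(4, 4), torch.randn(8)]
datas = [p.data for p in params]

# One contiguous buffer holding every tensor in the group.
coalesced = _flatten_dense_tensors(datas)

# In the real code the buffer is synchronized across ranks:
#     torch.distributed.broadcast(coalesced, 0, group=comm)
# Skipped here so the sketch runs standalone.

# Views of the buffer with the original shapes, copied back into the parameters.
synced = _unflatten_dense_tensors(coalesced, datas)
for p, s in zip(params, synced):
    p.data.copy_(s)

print([tuple(p.shape) for p in params])  # [(4, 4), (8,)]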
fmoe/gates.py (view file @ ba878d29)
@@ -8,14 +8,16 @@ import torch.nn.functional as F
 class ZeroGate(nn.Module):
-    def __init__(self, d_model, num_expert, world_size, top_k=2):
+    r'''
+    Guide all input samples to gate 0.
+    '''
+    def __init__(self, _1, _2, _3, top_k=2):
         super().__init__()
         self.top_k = top_k
 
     def forward(self, inp):
         r'''
-        The naive implementation simply calculates the top-k of a linear layer's
-        output.
+        All output to expert 1
         '''
         idx = torch.zeros(inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device)
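As the new docstring says, ZeroGate ignores its input and routes every sample to expert 0. A standalone sketch of just the index construction from this hunk, with made-up sizes; the rest of the gate interface (scores, return values) is not reproduced here.

import torch

inp = torch.randn(6, 512)   # hypothetical batch: 6 samples, d_model = 512
top_k = 2

# Same construction as in the hunk: one expert id per (sample, k) slot, all zeros.
idx = torch.zeros(inp.shape[0] * top_k, dtype=torch.int64, device=inp.device)
print(idx.shape)     # torch.Size([12])
print(idx.unique())  # tensor([0])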
fmoe/layers.py (view file @ ba878d29)
@@ -150,6 +150,10 @@ class FMoE(nn.Module):
             self.experts_fused = True
 
     def expert_fn(self, inp, fwd_expert_count):
+        r'''
+        The default expert function which either calls the experts as a whole
+        or as separate experts.
+        '''
         if self.experts_fused:
             return self.experts(inp, fwd_expert_count)
         outputs = []
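The hunk cuts off at `outputs = []`, the start of the non-fused branch. As a hedged sketch (not the verbatim FastMoE code), the per-expert path described by the new docstring could look like the following, assuming `experts` is an `nn.ModuleList` and `fwd_expert_count[i]` is the number of contiguous input rows assigned to expert `i`.

import torch
import torch.nn as nn

def expert_fn_unfused(experts, inp, fwd_expert_count):
    # Slice the batched input per expert and run each slice through its own module.
    outputs = []
    base = 0
    for i, count in enumerate(fwd_expert_count):
        count = int(count)
        if count == 0:
            continue
        outputs.append(experts[i](inp[base:base + count]))
        base += count
    return torch.cat(outputs, dim=0)

# Tiny smoke test: two linear "experts", 5 tokens split 2 / 3.
experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
out = expert_fn_unfused(experts, torch.randn(5, 4), torch.tensor([2, 3]))
print(out.shape)  # torch.Size([5, 4])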
fmoe/megatron.py (view file @ ba878d29)
@@ -3,22 +3,28 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
 lines of modification.
 See `examples/megatron` for usage instructions.
 '''
 import math
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
 from .utils import get_torch_default_comm
 
 
 class _FakeMegatronMLP(nn.Module):
     r'''
     A fake mlp without model parallelism for correctness testing
     '''
-    def __init__(self, args, group):
+    def __init__(self, args, _):
         super().__init__()
         self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
         self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)
 
     def forward(self, x):
         r'''
         Directly use GeLU
         '''
         x = self.fc1(x)
         x = F.gelu(x)
         x = self.fc2(x)
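For context on the class touched in this hunk: `_FakeMegatronMLP` only needs an `args` object exposing `hidden_size` and `hidden_hidden_size`. The sketch below rebuilds the same fc1 -> GeLU -> fc2 pipeline with a stand-in namespace to show the shapes involved; the real `forward`'s return value is cut off in the hunk above and is not reproduced.

from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for Megatron's argument object; only the two fields used above.
args = SimpleNamespace(hidden_size=16, hidden_hidden_size=64)

fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

x = torch.randn(2, 8, args.hidden_size)  # (batch, sequence, hidden)
x = fc2(F.gelu(fc1(x)))                  # same fc1 -> GeLU -> fc2 pipeline
print(x.shape)                           # torch.Size([2, 8, 16])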
@@ -71,7 +77,7 @@ class MegatronMLP(FMoETransformerMLP):
         r'''
         Initialize the weight as linear layers.
         As megatron is using fixed random seed for some nasty stuff, an
-        additional numpy rng is used.
+        additional numpy rng is used.
         '''
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
         _random_init_weight(self.experts.htoh4, rng)
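The visible change in this second hunk appears to be a whitespace-only docstring fix, which fits the commit message. The surrounding code is still worth a note: because Megatron pins the global random seeds, each rank derives its own numpy generator from an offset plus its rank. A small sketch of that per-rank seeding, with `rank` as a stand-in for `self.rank`:

import numpy as np

np.random.seed(1234)                     # stand-in for Megatron's pinned global seed
base = np.random.randint(2048)           # identical on every rank when seeds are pinned

for rank in range(4):                    # pretend there are 4 data-parallel ranks
    rng = np.random.default_rng(base + rank)
    print(rank, rng.standard_normal(3))  # each rank draws a distinct stream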