OpenDAS / FastMoE · Commit 3b82e379
authored Mar 13, 2021 by Rick Ho

fix lint

parent 43af1522
Showing 6 changed files with 19 additions and 11 deletions (+19 -11)
fmoe/distributed.py  +2 -1
fmoe/functions.py    +2 -1
fmoe/gates.py        +2 -1
fmoe/layers.py       +4 -2
fmoe/megatron.py     +5 -4
fmoe/transformer.py  +4 -2
fmoe/distributed.py
@@ -47,7 +47,8 @@ class DistributedGroupedDataParallel(nn.Module):
         else:
             self.comms["world"] = world_group

         def allreduce_params(no_scale=False, reduce_after=False, fp32_allreduce=False):
             groups = dict()
             for p in self.module.parameters():
                 if not p.requires_grad or p.grad is None:
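The loop that opens here skips parameters that have no gradient and collects the rest into groups. A minimal, self-contained sketch of that skip-and-bucket pattern, bucketing by gradient dtype for illustration; the helper name and the dtype key are inventions of this note, and the real allreduce_params continues past what the hunk shows:

import torch
from torch import nn

def group_grads_by_dtype(module):
    # Walk module.parameters() as the hunk above does, skip anything
    # without a gradient, then bucket the surviving gradients by dtype.
    groups = dict()
    for p in module.parameters():
        if not p.requires_grad or p.grad is None:
            continue
        groups.setdefault(p.grad.dtype, []).append(p.grad)
    return groups

model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
print({k: len(v) for k, v in group_grads_by_dtype(model).items()})
# {torch.float32: 2}  -> weight and bias gradients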
fmoe/functions.py
@@ -40,7 +40,8 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         )
     else:
         global_expert_count = local_expert_count
     fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
     return (
         pos,
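global_expert_count is laid out worker-major, so view(world_size, num_expert).sum(dim=0) counts how many tokens land on each local expert across all workers, and its total is the forwarded batch size. A toy worked example with invented numbers:

import torch

world_size, num_expert = 2, 3                  # invented sizes
global_expert_count = torch.tensor([4, 0, 2,   # counts from worker 0
                                    1, 5, 3])  # counts from worker 1
fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
print(fwd_expert_count)                        # tensor([5, 5, 5])
print(int(fwd_expert_count.sum().item()))      # 15, the fwd_batch_size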
fmoe/gates.py
@@ -23,7 +23,8 @@ class ZeroGate(nn.Module):
         idx = torch.zeros(
             inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
         )
         score = torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
         return idx, score.reshape(-1, 1, self.top_k)
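ZeroGate routes every token to expert 0 and gives each of its top_k copies the same weight 1 / top_k. A standalone sketch of the tensors built above, with an invented batch size and top_k:

import torch

batch, top_k = 4, 2                      # invented sizes
inp = torch.randn(batch, 8)

idx = torch.zeros(inp.shape[0] * top_k, dtype=torch.int64, device=inp.device)
score = torch.ones(inp.shape[0] * top_k, device=inp.device) / top_k

print(idx)                          # all zeros: every token goes to expert 0
print(score.reshape(-1, 1, top_k))  # uniform 0.5 weights, shape (4, 1, 2)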
fmoe/layers.py
@@ -114,7 +114,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
         fwd_batch_size,
     ) = moe_prepare_forward(gate, num_expert, world_size)
     x = MOEScatter.apply(
         inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
     )
     x = expert_fn(x, fwd_expert_count)
     x = MOEGather.apply(
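For reference, a single-process, top-1 analogue of the scatter -> expert_fn -> gather pipeline in this hunk, written with plain indexing; MOEScatter and MOEGather additionally move tokens between workers, which this sketch ignores, and every name below is illustrative rather than a FastMoE API:

import torch
from torch import nn

def local_moe_forward(inp, gate_idx, experts):
    # Reorder tokens so those routed to the same expert sit contiguously
    # (the local part of what MOEScatter does).
    pos = torch.argsort(gate_idx)
    expert_count = torch.bincount(gate_idx, minlength=len(experts))
    x = inp[pos]
    out = torch.empty_like(x)
    start = 0
    for e, count in enumerate(expert_count.tolist()):
        # The role of expert_fn: run each expert on its contiguous slice.
        if count:
            out[start:start + count] = experts[e](x[start:start + count])
        start += count
    # The local part of MOEGather: restore the original token order.
    return out[torch.argsort(pos)]

experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
inp = torch.randn(5, 8)
gate_idx = torch.tensor([2, 0, 2, 1, 0])
print(local_moe_forward(inp, gate_idx, experts).shape)  # torch.Size([5, 8])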
@@ -165,7 +166,8 @@ class FMoE(nn.Module):
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
         if expert is not None:
             self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
             self.experts_fused = False
         else:
             self.experts_fused = True
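When an expert constructor is supplied, FMoE builds one expert per local expert index with expert(d_model) and keeps them in an nn.ModuleList. A hypothetical expert class that fits that constructor call; what FMoE later expects from its forward is not visible in this hunk:

import torch
from torch import nn

class ToyExpert(nn.Module):          # hypothetical, not part of FastMoE
    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        return self.fc(x)

num_expert, d_model = 4, 16          # invented sizes
experts = nn.ModuleList([ToyExpert(d_model) for _ in range(num_expert)])
print(len(experts), experts[0](torch.randn(2, d_model)).shape)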
fmoe/megatron.py
@@ -41,7 +41,7 @@ def _megatron_init_method(self, rng, sigma):
     device = self.weight.device
     dtype = self.weight.dtype
     weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
-    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)
     if self.bias is not None:
         # Always initialize bias to zero.
@@ -60,13 +60,13 @@ def _random_init_weight(self, rng):
     device = self.weight.device
     dtype = self.weight.dtype
     weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
-    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)
     if self.bias is not None:
         fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
         bound = 1 / math.sqrt(fan_in)
         bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
-        self.bias.data = torch.tensor(bias, dtype=dtype, device=device)
+        self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)

 class MegatronMLP(FMoETransformerMLP):
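Both hunks above swap torch.tensor(array, dtype=..., device=...) for torch.from_numpy(array).to(dtype=..., device=...). The resulting values are identical either way; torch.tensor always copies the NumPy buffer and then casts, while from_numpy wraps the buffer and leaves the dtype/device conversion to .to(). A standalone check with made-up shapes:

import numpy as np
import torch

rng = np.random.default_rng(0)
weight = rng.normal(loc=0.0, scale=0.02, size=(4, 4))   # float64 ndarray

old = torch.tensor(weight, dtype=torch.float32, device="cpu")
new = torch.from_numpy(weight).to(dtype=torch.float32, device="cpu")
print(torch.equal(old, new))  # True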
@@ -77,7 +77,8 @@ class MegatronMLP(FMoETransformerMLP):
     def __init__(self, args, group):
         assert (
             args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
             == 0
         ), "Batch size x sequence length should be multiple of mp size"
         if not args.distributed_experts:
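The assertion only enforces a divisibility constraint on the Megatron-LM arguments: the tokens in one micro-batch must split evenly across the tensor model parallel group. With invented numbers:

seq_length, micro_batch_size, tensor_model_parallel_size = 1024, 4, 8  # invented
tokens = seq_length * micro_batch_size                   # 4096
assert tokens % tensor_model_parallel_size == 0          # 4096 % 8 == 0, passes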
fmoe/transformer.py
@@ -15,8 +15,10 @@ class _Expert(nn.Module):
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True,
+                                rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True,
+                                rank=rank)
         self.activation = activation

     def forward(self, inp, fwd_expert_count):
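_Expert wires two FMoELinear layers, an h -> 4h projection and a 4h -> h projection, around an activation; as the forward signature in the hunk shows, it receives the input together with the per-expert token counts. A rough single-expert analogue using plain nn.Linear, with invented sizes and without the expert batching FMoELinear provides:

import torch
from torch import nn

d_model, d_hidden = 16, 64                   # invented sizes
htoh4 = nn.Linear(d_model, d_hidden, bias=True)
h4toh = nn.Linear(d_hidden, d_model, bias=True)
activation = nn.functional.gelu

inp = torch.randn(8, d_model)
out = h4toh(activation(htoh4(inp)))
print(out.shape)                             # torch.Size([8, 16])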