OpenDAS / FastMoE · Commits

Commit 6900f1de authored Jan 29, 2021 by Rick Ho
parent 437afda2

fix python bugs

Showing 4 changed files with 20 additions and 72 deletions (+20 −72)
fmoe/__init__.py        +1  −1
fmoe/fmoe_functions.py  +10 −10
fmoe/layers.py          +8  −5
fmoe/megatron.py        +1  −56
fmoe/__init__.py

 from .moe import BruteForceMoE
-from .fmoe import FMoELinear, FMoENaiveGate, FMoETransformerMLP
+from .layers import FMoELinear, FMoENaiveGate, FMoETransformerMLP
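The rename of fmoe/fmoe.py to fmoe/layers.py (shown further down) only changes the intra-package import; the names re-exported from the package root stay the same. A minimal sketch of downstream usage that keeps working, assuming the package is importable as fmoe:

# Hypothetical downstream import check: these names are still re-exported
# from the package root after the fmoe.py -> layers.py rename.
from fmoe import BruteForceMoE, FMoELinear, FMoENaiveGate, FMoETransformerMLP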
fmoe/fmoe_functions.py

...
@@ -9,8 +9,8 @@ def moe_prepare_forward(gate, num_expert, world_size):
     with torch.no_grad():
         _, pos = torch.sort(gate)
         gate_idx, gate_count = torch.unique(gate, return_counts=True)
-        local_expert_count = torch.zeros(weight.shape[0] * world_size,
-                device=weight.device, dtype=torch.long)
+        local_expert_count = torch.zeros(num_expert * world_size,
+                device=gate.device, dtype=torch.long)
         local_expert_count.index_put_((gate_idx.long(), ), gate_count)
         global_expert_count, = fmoe_cuda.expert_exchange(
...
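The fix above replaces references to weight, which is not an argument of moe_prepare_forward, with the num_expert parameter and the gate tensor's device. A standalone sketch of the counting step, with illustrative values that are not taken from the repository, shows what index_put_ produces here:

import torch

# Illustrative values only: 4 experts per worker, 2 workers, one gate index per token.
num_expert, world_size = 4, 2
gate = torch.tensor([0, 3, 3, 5, 1, 0])

gate_idx, gate_count = torch.unique(gate, return_counts=True)
local_expert_count = torch.zeros(num_expert * world_size, device=gate.device,
                                 dtype=torch.long)
# Scatter the per-expert token counts into a dense vector of length
# num_expert * world_size, leaving unrouted experts at zero.
local_expert_count.index_put_((gate_idx.long(), ), gate_count)
print(local_expert_count)  # tensor([2, 1, 0, 2, 0, 1, 0, 0])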
@@ -28,7 +28,7 @@ class MOEScatter(Function):
             fwd_batch_size, world_size):
         local_input_buf, = fmoe_cuda.local_gather(inp, pos)
         if world_size > 1:
-            global_input_buf, = moe_cuda.global_scatter(local_input_buf,
+            global_input_buf, = fmoe_cuda.global_scatter(local_input_buf,
                     local_expert_count, global_expert_count,
                     fwd_batch_size, world_size)
         else:
...
@@ -43,19 +43,19 @@ class MOEScatter(Function):
         (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args
         if world_size > 1:
-            local_grad_in, = moe_cuda.global_gather(global_grad_out,
+            local_grad_in, = fmoe_cuda.global_gather(global_grad_out,
                     local_expert_count, global_expert_count,
                     local_batch_size, world_size)
         else:
             local_grad_in = global_grad_in
-        grad_in, = moe_cuda.local_scatter(local_grad_in, pos)
+        grad_in, = fmoe_cuda.local_scatter(local_grad_in, pos)
         return grad_in, None, None, None, None, None


 class MOELinear(Function):
     @staticmethod
     def forward(ctx, global_input_buf, weight, fwd_expert_count):
-        global_output_buf, = moe_cuda.forward(global_input_buf, weight,
+        global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
                 fwd_expert_count)
         variables = (input_buf, weight, fwd_expert_count)
         ctx.save_for_backward(*variables)
...
@@ -74,12 +74,12 @@ class MOEGather(Function):
     def forward(ctx, global_output_buf, pos,
             local_expert_count, global_expert_count,
             local_batch_size, world_size):
         if world_size > 1:
-            local_output_buf, = moe_cuda.global_gather(global_output_buf,
+            local_output_buf, = fmoe_cuda.global_gather(global_output_buf,
                     local_expert_count, global_expert_count,
                     local_batch_size, world_size)
         else:
             local_output_buf = global_output_buf
-        output, = moe_cuda.local_scatter(local_output_buf, pos)
+        output, = fmoe_cuda.local_scatter(local_output_buf, pos)
         ctx.moe_args = fwd_batch_size, world_size
         variables = (pos, local_expert_count, global_expert_count)
...
@@ -90,9 +90,9 @@ class MOEGather(Function):
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
         fwd_batch_size, world_size = ctx.moe_args
-        grad_out_buf = moe_cuda.local_gather(grad_out.contiguous(), pos)
+        grad_out_buf = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
         if world_size > 1:
-            global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
+            global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                     local_expert_count, global_expert_count,
                     fwd_batch_size, world_size)
         else:
...
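Every change in this file swaps the stale moe_cuda module name for the renamed fmoe_cuda extension inside torch.autograd.Function subclasses. For orientation, a self-contained toy Function in the same shape as MOEScatter: forward permutes rows by an index, backward undoes the permutation and returns one gradient (or None) per forward argument. The class and variable names here are illustrative, not part of FastMoE:

import torch
from torch.autograd import Function

class ToyScatter(Function):
    """Toy analogue of MOEScatter: gather rows of `inp` in the order given by `pos`."""
    @staticmethod
    def forward(ctx, inp, pos):
        ctx.save_for_backward(pos)
        return inp[pos]

    @staticmethod
    def backward(ctx, grad_out):
        pos, = ctx.saved_tensors
        grad_in = torch.empty_like(grad_out)
        grad_in[pos] = grad_out          # inverse permutation of the gradient
        # One entry per forward() argument after ctx, as in MOEScatter.backward.
        return grad_in, None

x = torch.randn(4, 3, requires_grad=True)
pos = torch.tensor([2, 0, 3, 1])
ToyScatter.apply(x, pos).sum().backward()
print(x.grad.shape)  # torch.Size([4, 3])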
fmoe/fmoe.py → fmoe/layers.py

 from .fmoe_functions import *
 import torch.nn as nn
 import torch.nn.functional as F


 class FMoELinear(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
-        super(FMoE, self).__init__()
+        super(FMoELinear, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
...
@@ -21,10 +22,11 @@ class FMoELinear(nn.Module):
         return MOELinear.apply(inp, self.weight, fwd_expert_count)


-class FMoENaiveGate(nn.module):
-    def __init__(self, num_expert=32, world_size=1, top_k=2):
+class FMoENaiveGate(nn.Module):
+    def __init__(self, d_model, num_expert, world_size, top_k=2):
         super(FMoENaiveGate, self).__init__()
         self.gate = nn.Linear(d_model, num_expert * world_size)
         self.top_k = top_k

     def forward(self, inp):
         gate = self.gate(inp)
...
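The gate now takes d_model first and subclasses nn.Module (capital M) instead of the non-existent nn.module. A minimal sketch of what such a naive gate computes, written as an illustrative re-implementation rather than the file's exact contents:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveGateSketch(nn.Module):
    """Illustrative gate: one linear score per (expert, worker), keep top_k per token."""
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super(NaiveGateSketch, self).__init__()
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
        gate = self.gate(inp)                              # (tokens, num_expert * world_size)
        val, idx = torch.topk(gate, k=self.top_k, dim=-1)  # top_k experts per token
        return idx, F.softmax(val, dim=-1)

g = NaiveGateSketch(d_model=16, num_expert=4, world_size=1)
idx, score = g(torch.randn(8, 16))   # both of shape (8, 2)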
@@ -53,7 +55,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
         return x


-class FMoETransformerMLP(nn.module):
+class FMoETransformerMLP(nn.Module):
     def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
             world_size=None, activation=torch.nn.functional.gelu,
             top_k=2, pre_lnorm=False):
...
@@ -64,11 +66,12 @@ class FMoETransformerMLP(nn.module):
         self.world_size = world_size
         self.activation = activation
         self.pre_lnorm = pre_lnorm
         self.top_k = top_k
         self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
         self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
-        self.gate = FMoENaivegate(num_expert, world_size, top_k)
+        self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
         self.layer_norm = nn.LayerNorm(d_model)
         self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
...
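A hedged construction sketch based only on the __init__ signature visible in this hunk; the keyword values repeat the defaults shown above, while the single-worker world_size=1 and the assumption that the fmoe_cuda extension is built are mine:

import torch
from fmoe.layers import FMoETransformerMLP  # module path introduced by this commit

# Illustrative single-process instantiation; arguments mirror the defaults in the diff.
mlp = FMoETransformerMLP(num_expert=32, d_model=1024, d_hidden=4096,
                         world_size=1, activation=torch.nn.functional.gelu,
                         top_k=2, pre_lnorm=False)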
fmoe/megatron.py

-from torch import nn
-from .moe import FMoE
-from .moe_function import moe
-from .fmoe import FMoETransformerMLP
-
-
-class FFFN(nn.Module):
-    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
-            world_size=None, activation=torch.nn.functional.gelu,
-            top_k=2, pre_lnorm=False):
-        super(FFFN, self).__init__()
-        self.d_model = d_model
-        self.d_hidden = d_hidden
-        self.world_size = world_size
-        self.activation = activation
-        self.top_k = top_k
-        self.pre_lnorm = pre_lnorm
-        self.htoh4 = FMoE(num_expert, d_model, d_hidden,
-                world_size=world_size)
-        self.h4toh = FMoE(num_expert, d_hidden, d_model,
-                world_size=world_size)
-        self.gate = nn.Linear(d_model, num_expert * world_size)
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model,
-            dtype=torch.float32))
-
-    def forward(self, inp):
-        # import pdb; pdb.set_trace()
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k,
-                dim=-1, largest=True, sorted=False)  # [.. x top_k]
-        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
-        # (BxL) x 1 x top_k
-        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
-        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
-        inp = inp.view(-1, self.d_model).repeat_interleave(
-                repeats=self.top_k, dim=0)  # (BxLxtop_k) x d_model
-        x = self.htoh4(inp, gate_top_k_idx)
-        x = self.activation(x)
-        x = self.h4toh(x, gate_top_k_idx)
-        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
-        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
-        core_out = core_out.view(residual.size(0), residual.size(1),
-                self.d_model)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output, self.bias
+from .layers import FMoETransformerMLP


 def create_moe_mlp(args):
...
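The FFFN class deleted above duplicated the top-k routing and expert-output combination that FMoETransformerMLP in fmoe/layers.py now provides, so megatron.py keeps only the import. The combine step the removed forward performed is, in outline, the following standalone sketch (shapes are made up for illustration):

import torch
import torch.nn.functional as F

# Standalone sketch of the combine step from the removed FFFN.forward:
# weight each token's top_k expert outputs by softmaxed gate scores and sum them.
tokens, top_k, d_model = 6, 2, 8
gate_top_k_val = torch.randn(tokens, top_k)           # raw scores of the chosen experts
expert_out = torch.randn(tokens, top_k, d_model)      # (BxL) x top_k x d_model

gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)   # (BxL) x 1 x top_k
combined = torch.bmm(gate_score, expert_out).squeeze(1)       # (BxL) x d_model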