OpenDAS / FastMoE · Commits · fdbac1df

Commit fdbac1df
Authored Feb 02, 2021 by Sengxian
Format using black and add model_parallel_rank
Parent: ae658b89

Showing 3 changed files with 131 additions and 69 deletions:
- fmoe/functions.py  +70 -34
- fmoe/layers.py     +50 -28
- fmoe/megatron.py   +11 -7
fmoe/functions.py
@@ -12,31 +12,48 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
     with torch.no_grad():
         _, pos = torch.sort(gate)
         gate_idx, gate_count = torch.unique(gate, return_counts=True)
-        local_expert_count = torch.zeros(num_expert * world_size, device=gate.device, dtype=torch.long)
-        local_expert_count.index_put_((gate_idx.long(), ), gate_count)
+        local_expert_count = torch.zeros(
+            num_expert * world_size, device=gate.device, dtype=torch.long
+        )
+        local_expert_count.index_put_((gate_idx.long(),), gate_count)
         if world_size > 1:
-            global_expert_count, = fmoe_cuda.expert_exchange(local_expert_count, num_expert, world_size)
+            (global_expert_count,) = fmoe_cuda.expert_exchange(
+                local_expert_count, num_expert, world_size
+            )
         else:
             global_expert_count = local_expert_count
-        fwd_expert_count = global_expert_count.view(world_size,
-                num_expert).sum(dim=0)
+        fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
         fwd_batch_size = int(fwd_expert_count.sum().item())
-    return (pos, local_expert_count.cpu(), global_expert_count.cpu(), fwd_expert_count.cpu(), fwd_batch_size)
+    return (
+        pos,
+        local_expert_count.cpu(),
+        global_expert_count.cpu(),
+        fwd_expert_count.cpu(),
+        fwd_batch_size,
+    )


 class MOEScatter(Function):
     @staticmethod
-    def forward(ctx, inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size):
-        local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
+    def forward(
+        ctx,
+        inp,
+        pos,
+        local_expert_count,
+        global_expert_count,
+        fwd_batch_size,
+        world_size,
+    ):
+        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
         if world_size > 1:
-            global_input_buf, = fmoe_cuda.global_scatter(local_input_buf, local_expert_count, global_expert_count, fwd_batch_size, world_size)
+            (global_input_buf,) = fmoe_cuda.global_scatter(
+                local_input_buf,
+                local_expert_count,
+                global_expert_count,
+                fwd_batch_size,
+                world_size,
+            )
         else:
             global_input_buf = local_input_buf
         ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
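For orientation, the bookkeeping that moe_prepare_forward does before any communication can be reproduced with plain PyTorch. The sketch below is an illustration only: it assumes a single process (world_size = 1) and skips the fmoe_cuda.expert_exchange call, so it covers the counting logic rather than the distributed path.

    import torch

    # Toy gate output: each of 6 tokens is routed to one of 4 local experts.
    num_expert, world_size = 4, 1
    gate = torch.tensor([2, 0, 2, 3, 0, 2])

    with torch.no_grad():
        _, pos = torch.sort(gate)  # token positions grouped by expert id
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        local_expert_count.index_put_((gate_idx.long(),), gate_count)

    # With world_size == 1 there is nothing to exchange across ranks.
    global_expert_count = local_expert_count
    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
    fwd_batch_size = int(fwd_expert_count.sum().item())

    print(local_expert_count)  # tensor([2, 0, 3, 1])
    print(fwd_batch_size)      # 6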
@@ -50,20 +67,25 @@ class MOEScatter(Function):
         (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args
         if world_size > 1:
-            local_grad_in, = fmoe_cuda.global_gather(global_grad_in, local_expert_count, global_expert_count, local_batch_size, world_size)
+            (local_grad_in,) = fmoe_cuda.global_gather(
+                global_grad_in,
+                local_expert_count,
+                global_expert_count,
+                local_batch_size,
+                world_size,
+            )
         else:
             local_grad_in = global_grad_in
-        grad_in, = fmoe_cuda.local_gather(local_grad_in, pos)
+        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
         return grad_in, None, None, None, None, None


 class MOELinear(Function):
     @staticmethod
     def forward(ctx, global_input_buf, weight, fwd_expert_count):
-        global_output_buf, = fmoe_cuda.forward(global_input_buf, weight, fwd_expert_count)
+        (global_output_buf,) = fmoe_cuda.forward(
+            global_input_buf, weight, fwd_expert_count
+        )
         variables = (global_input_buf, weight, fwd_expert_count)
         ctx.save_for_backward(*variables)
         return global_output_buf
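Most hunks in this file only swap the spelling "x, = f(...)" for "(x,) = f(...)". Both forms unpack a one-element tuple identically; black simply prefers the parenthesized version. A two-line check with a stand-in function:

    def f():
        return ("only",)  # a one-element tuple, like the fmoe_cuda bindings return

    x, = f()
    (y,) = f()
    assert x == y == "only"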
@@ -72,21 +94,33 @@ class MOELinear(Function):
     def backward(ctx, grad_out):
         (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = fmoe_cuda.backward(grad_out, input_buf, weight, fwd_expert_count)
+        grad_inp_buf, grad_weight = fmoe_cuda.backward(
+            grad_out, input_buf, weight, fwd_expert_count
+        )
         return grad_inp_buf, grad_weight, None


 class MOEGather(Function):
     @staticmethod
-    def forward(ctx, global_output_buf, pos, local_expert_count, global_expert_count, local_batch_size, world_size):
+    def forward(
+        ctx,
+        global_output_buf,
+        pos,
+        local_expert_count,
+        global_expert_count,
+        local_batch_size,
+        world_size,
+    ):
         if world_size > 1:
-            local_output_buf, = fmoe_cuda.global_gather(global_output_buf, local_expert_count, global_expert_count, local_batch_size, world_size)
+            (local_output_buf,) = fmoe_cuda.global_gather(
+                global_output_buf,
+                local_expert_count,
+                global_expert_count,
+                local_batch_size,
+                world_size,
+            )
         else:
             local_output_buf = global_output_buf
-        output, = fmoe_cuda.local_gather(local_output_buf, pos)
+        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)
         ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
@@ -97,13 +131,15 @@ class MOEGather(Function):
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
         local_batch_size, fwd_batch_size, world_size = ctx.moe_args
-        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
+        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
         if world_size > 1:
-            global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf, local_expert_count, global_expert_count, fwd_batch_size, world_size)
+            (global_grad_out_buf,) = fmoe_cuda.global_scatter(
+                grad_out_buf,
+                local_expert_count,
+                global_expert_count,
+                fwd_batch_size,
+                world_size,
+            )
         else:
             global_grad_out_buf = grad_out_buf
         return global_grad_out_buf, None, None, None, None, None
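MOEScatter, MOELinear and MOEGather all follow the standard torch.autograd.Function pattern that the reformatting keeps intact: a static forward that stashes what backward needs (tensors via ctx.save_for_backward, plain values on attributes such as ctx.moe_args) and a static backward that returns one gradient per forward argument, with None for non-tensor inputs. A minimal self-contained instance of the same pattern, not the FastMoE kernels themselves:

    import torch
    from torch.autograd import Function

    class ScaleByCount(Function):
        """Scale the input by a Python int, mirroring the ctx bookkeeping above."""

        @staticmethod
        def forward(ctx, inp, count):
            ctx.save_for_backward(inp)
            ctx.moe_args = (count,)  # non-tensor state lives directly on ctx
            return inp * count

        @staticmethod
        def backward(ctx, grad_out):
            (inp,) = ctx.saved_tensors  # available if the gradient needed it
            (count,) = ctx.moe_args
            # one gradient per forward argument; None for the non-tensor count
            return grad_out * count, None

    x = torch.ones(3, requires_grad=True)
    ScaleByCount.apply(x, 4).sum().backward()
    print(x.grad)  # tensor([4., 4., 4.])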
fmoe/layers.py
@@ -9,8 +9,7 @@ class FMoELinear(nn.Module):
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
-        self.weight = nn.Parameter(
-            torch.Tensor(num_expert, out_feat, in_feat))
+        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
         self.reset_parameters()

     def reset_parameters(self):
@@ -30,35 +29,51 @@ class FMoENaiveGate(nn.Module):
     def forward(self, inp):
         gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False) # [.. x top_k]
+        gate_top_k_val, gate_top_k_idx = torch.topk(
+            gate, k=self.top_k, dim=-1, largest=True, sorted=False
+        )  # [.. x top_k]
         gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
-        # (BxL) x 1 x top_k
-        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
-        gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
+        # (BxL) x 1 x top_k
+        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
+        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
         return gate_top_k_idx, gate_score


 def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
-    (pos, local_expert_count, global_expert_count, fwd_expert_count, fwd_batch_size) = moe_prepare_forward(gate, num_expert, world_size)
-    x = MOEScatter.apply(inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size)
+    (
+        pos,
+        local_expert_count,
+        global_expert_count,
+        fwd_expert_count,
+        fwd_batch_size,
+    ) = moe_prepare_forward(gate, num_expert, world_size)
+    x = MOEScatter.apply(
+        inp, pos, local_expert_count, global_expert_count, fwd_batch_size, world_size
+    )
     for i, l in enumerate(linears):
         if i:
             x = activation(x)
         x = l(x, fwd_expert_count)
-    x = MOEGather.apply(x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size)
+    x = MOEGather.apply(
+        x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size
+    )
     return x


 class FMoETransformerMLP(nn.Module):
-    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096, world_size=1, activation=torch.nn.functional.gelu, top_k=2, pre_lnorm=False):
+    def __init__(
+        self,
+        num_expert=32,
+        d_model=1024,
+        d_hidden=4096,
+        world_size=1,
+        activation=torch.nn.functional.gelu,
+        top_k=2,
+        pre_lnorm=False,
+        model_parallel_rank=-1,
+    ):
         super(FMoETransformerMLP, self).__init__()
         self.num_expert = num_expert
         self.d_model = d_model
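The gate in FMoENaiveGate.forward is ordinary top-k routing: a linear score per expert, torch.topk to keep the top_k experts per token, and a softmax over just those k scores. The standalone sketch below uses toy, hypothetical sizes and no FastMoE dependency; it reproduces the shapes that the inline comments refer to.

    import torch
    import torch.nn.functional as F

    batch_x_len, num_expert, top_k = 5, 8, 2     # (BxL) tokens, toy sizes
    gate = torch.randn(batch_x_len, num_expert)  # output of the linear gate

    gate_top_k_val, gate_top_k_idx = torch.topk(
        gate, k=top_k, dim=-1, largest=True, sorted=False
    )  # [.. x top_k]
    gate_top_k_val = gate_top_k_val.view(-1, top_k)

    gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)  # (BxL) x 1 x top_k
    gate_top_k_idx = gate_top_k_idx.view(-1)                     # (BxLxtop_k)

    print(gate_score.shape)      # torch.Size([5, 1, 2])
    print(gate_top_k_idx.shape)  # torch.Size([10])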
@@ -69,13 +84,15 @@ class FMoETransformerMLP(nn.Module):
         self.top_k = top_k
         self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
         self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
         self.layer_norm = nn.LayerNorm(d_model)
-        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model, dtype=torch.float32))
+        self.bias = torch.nn.parameter.Parameter(
+            torch.zeros(d_model, dtype=torch.float32)
+        )
+        self.model_parallel_rank = model_parallel_rank

     def forward(self, inp):
         residual = inp
@@ -85,18 +102,23 @@ class FMoETransformerMLP(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
         # TODO: merge replication into local_scatter
-        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k, dim=0) # (BxLxtop_k) x d_model
-        x = _fmoe_full_forward(inp, gate_top_k_idx, [self.htoh4, self.h4toh], self.activation, self.num_expert, self.world_size)
-        core_out = x.view(-1, self.top_k, self.d_model) # (BxL) x top_k x d_model
-        core_out = torch.bmm(gate_score, core_out) # (BxL) x 1 x d_model
+        inp = inp.view(-1, self.d_model).repeat_interleave(
+            repeats=self.top_k, dim=0
+        )  # (BxLxtop_k) x d_model
+        x = _fmoe_full_forward(
+            inp,
+            gate_top_k_idx,
+            [self.htoh4, self.h4toh],
+            self.activation,
+            self.num_expert,
+            self.world_size,
+        )
+        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
+        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
         core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
         output = core_out + residual
         if not self.pre_lnorm:
             output = self.layer_norm(output)
         return output, self.bias
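The tail of FMoETransformerMLP.forward recombines the top_k expert outputs of each token with a batched matmul against the gate scores. A shape-only walk-through with toy dimensions, independent of the CUDA kernels:

    import torch

    B, L, d_model, top_k = 2, 3, 16, 2  # toy sizes
    inp = torch.randn(B, L, d_model)
    gate_score = torch.softmax(torch.randn(B * L, 1, top_k), dim=-1)  # (BxL) x 1 x top_k

    # replicate each token once per selected expert, as done before _fmoe_full_forward
    x = inp.view(-1, d_model).repeat_interleave(repeats=top_k, dim=0)  # (BxLxtop_k) x d_model
    # ... the expert MLPs would transform x here without changing its shape ...

    core_out = x.view(-1, top_k, d_model)       # (BxL) x top_k x d_model
    core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
    core_out = core_out.view(B, L, d_model)
    print(core_out.shape)  # torch.Size([2, 3, 16])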
fmoe/megatron.py
@@ -2,11 +2,15 @@ from .layers import FMoETransformerMLP
 def create_moe_mlp(args):
-    assert args.num_experts % args.model_parallel_size == 0, 'Num experts should be multiple of mp size'
-    num_experts = args.num_experts // args.model_parallel_size
-    fmoe = FMoETransformerMLP(num_experts, d_model=args.hidden_size, d_hidden=args.hidden_size * 4, world_size=args.model_parallel_size)
+    assert (
+        args.num_experts % args.model_parallel_size == 0
+    ), "Num experts should be multiple of mp size"
+    num_experts = args.num_experts // args.model_parallel_size
+    fmoe = FMoETransformerMLP(
+        num_experts,
+        d_model=args.hidden_size,
+        d_hidden=args.hidden_size * 4,
+        world_size=args.model_parallel_size,
+        model_parallel_rank=args.model_parallel_rank,
+    )
     return fmoe
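With this change, create_moe_mlp forwards Megatron's model_parallel_rank into FMoETransformerMLP alongside the existing sizes. The snippet below sketches how the helper is driven; the args namespace is a hypothetical stand-in for Megatron-LM's parsed arguments, and actually constructing the module still requires the compiled fmoe_cuda extension to be importable.

    from types import SimpleNamespace

    from fmoe.megatron import create_moe_mlp

    # Hypothetical stand-in for Megatron-LM's parsed arguments.
    args = SimpleNamespace(
        num_experts=8,          # must be a multiple of model_parallel_size
        hidden_size=1024,
        model_parallel_size=2,  # 8 // 2 = 4 experts per model-parallel rank
        model_parallel_rank=0,  # newly threaded through to FMoETransformerMLP
    )

    fmoe_mlp = create_moe_mlp(args)
    print(fmoe_mlp.num_expert, fmoe_mlp.model_parallel_rank)  # 4 0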