OpenDAS / FastMoE · Commits
Commit 0fea2991, authored Jan 29, 2021 by Rick Ho

    fix more bugs to make the layers run in the model

Parent: 6900f1de
Showing 2 changed files with 16 additions and 13 deletions:

    fmoe/fmoe_functions.py  (+12, -8)
    fmoe/layers.py          (+4, -5)
fmoe/fmoe_functions.py
@@ -13,10 +13,13 @@ def moe_prepare_forward(gate, num_expert, world_size):
             device=gate.device, dtype=torch.long)
     local_expert_count.index_put_((gate_idx.long(),), gate_count)
-    global_expert_count, = fmoe_cuda.expert_exchange(
-            local_expert_count, num_expert, world_size)
+    if world_size > 1:
+        global_expert_count, = fmoe_cuda.expert_exchange(
+                local_expert_count, num_expert, world_size)
+    else:
+        global_expert_count = local_expert_count
     fwd_expert_count = global_expert_count.view(world_size,
-            num_expert).sum(dim=0).cpu()
+            num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
     return (pos, local_expert_count.cpu(), global_expert_count.cpu(),
             fwd_expert_count.cpu(), fwd_batch_size)
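The new world_size > 1 guard lets this counting path run in a single process without the cross-worker count exchange. Below is a minimal, pure-PyTorch sketch of that path, not the fmoe_cuda kernels; it assumes gate_idx and gate_count come from torch.unique on the gate indices, which is outside this hunk.

    import torch

    def count_per_expert(gate, num_expert, world_size=1):
        # gate: LongTensor of chosen expert indices, one per (replicated) token
        local_expert_count = torch.zeros(num_expert * world_size,
                                         device=gate.device, dtype=torch.long)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count.index_put_((gate_idx.long(),), gate_count)
        # single worker: nothing to exchange, so global counts == local counts,
        # which is exactly the fallback branch added in this hunk
        global_expert_count = local_expert_count
        fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())
        return local_expert_count, global_expert_count, fwd_expert_count, fwd_batch_size

    print(count_per_expert(torch.tensor([0, 2, 2, 1, 0, 2]), num_expert=3))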
@@ -35,6 +38,7 @@ class MOEScatter(Function):
             global_input_buf = local_input_buf
         ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
+        ctx.save_for_backward(*variables)
         return global_input_buf

     @staticmethod
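The single added line matters because a torch.autograd.Function can only hand tensors to its backward through ctx.save_for_backward; without it, ctx.saved_tensors is empty and the backward pass fails. A generic sketch of the pattern (not the FastMoE kernels):

    import torch
    from torch.autograd import Function

    class Scale(Function):
        @staticmethod
        def forward(ctx, x, s):
            ctx.save_for_backward(s)   # forgetting this leaves ctx.saved_tensors empty
            return x * s

        @staticmethod
        def backward(ctx, grad_out):
            (s,) = ctx.saved_tensors
            return grad_out * s, None  # one gradient slot per forward input

    x = torch.randn(4, requires_grad=True)
    Scale.apply(x, torch.tensor(2.0)).sum().backward()
    print(x.grad)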
@@ -57,14 +61,14 @@ class MOELinear(Function):
     def forward(ctx, global_input_buf, weight, fwd_expert_count):
         global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
                 fwd_expert_count)
-        variables = (input_buf, weight, fwd_expert_count)
+        variables = (global_input_buf, weight, fwd_expert_count)
         ctx.save_for_backward(*variables)
         return global_output_buf

     @staticmethod
     def backward(ctx, grad_out):
         (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = ome_cuda.backward(
+        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                 grad_out, input_buf, weight, fwd_expert_count)
         return grad_inp_buf, grad_weight, None
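Both fixes here are name errors: the forward must save the buffer it actually received (global_input_buf), and the backward must call the fmoe_cuda extension rather than the misspelled ome_cuda. For reference, a loop-based sketch of what MOELinear's forward is assumed to compute, with rows grouped by expert and each group multiplied by that expert's weight (the real kernel batches these GEMMs on the GPU):

    import torch

    def expert_linear_reference(global_input_buf, weight, fwd_expert_count):
        # global_input_buf: (sum(fwd_expert_count), d_in), rows sorted by expert
        # weight: (num_expert, d_out, d_in)
        outputs, base = [], 0
        for i, cnt in enumerate(fwd_expert_count.tolist()):
            outputs.append(global_input_buf[base:base + cnt] @ weight[i].t())
            base += cnt
        return torch.cat(outputs, dim=0)

    out = expert_linear_reference(torch.randn(6, 4), torch.randn(3, 5, 4),
                                  torch.tensor([2, 1, 3]))
    print(out.shape)  # torch.Size([6, 5])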
@@ -81,7 +85,7 @@ class MOEGather(Function):
             local_output_buf = global_output_buf
         output, = fmoe_cuda.local_scatter(local_output_buf, pos)
-        ctx.moe_args = fwd_batch_size, world_size
+        ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
         ctx.save_for_backward(*variables)
         return output
@@ -89,8 +93,8 @@ class MOEGather(Function):
     @staticmethod
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
-        fwd_batch_size, world_size = ctx.moe_args
-        grad_out_buf = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
+        local_batch_size, fwd_batch_size, world_size = ctx.moe_args
+        grad_out_buf, = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
         if world_size > 1:
             global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                     local_expert_count, global_expert_count,
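ctx.moe_args now carries both the local batch size and the forward buffer size so backward can rebuild buffers of the right shape; tensors still travel through save_for_backward. The pos-based reordering that local_scatter / local_gather perform lives in fmoe_cuda, but a pure-PyTorch stand-in with the assumed semantics (pos maps between the gate order and the expert-sorted order) looks like this:

    import torch

    def local_gather(buf, pos):
        # pick rows of buf in the order given by pos
        return buf.index_select(0, pos)

    def local_scatter(buf, pos):
        # inverse reordering: row pos[j] of the result comes from buf[j]
        out = torch.empty_like(buf)
        out[pos] = buf
        return out

    x = torch.arange(8.0).view(4, 2)
    pos = torch.tensor([2, 0, 3, 1])
    assert torch.equal(local_gather(local_scatter(x, pos), pos), x)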
fmoe/layers.py
@@ -49,7 +49,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
     for i, l in enumerate(linears):
         if i:
             x = activation(x)
-        x = l(x)
+        x = l(x, fwd_expert_count)
     x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
             inp.shape[0], world_size)
     return x
@@ -78,16 +78,15 @@ class FMoETransformerMLP(nn.Module):
             dtype=torch.float32))

     def forward(self, inp):
-        # import pdb; pdb.set_trace()
         residual = inp
         if self.pre_lnorm:
             inp = self.layer_norm(inp)
-        inp = inp.view(-1, self.d_model).repeat_interleave(
-                repeats=self.top_k, dim=0) # (BxLxtop_k) x d_model
         gate_top_k_idx, gate_score = self.gate(inp)
+        # TODO: merge replication into local_scatter
+        inp = inp.view(-1, self.d_model).repeat_interleave(
+                repeats=self.top_k, dim=0) # (BxLxtop_k) x d_model
         x = _fmoe_full_forward(inp, gate_top_k_idx,
                 [self.htoh4, self.h4toh], self.activation,
                 self.num_expert, self.world_size)
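The forward is reordered so the gate sees each token once and replication happens afterwards: repeat_interleave makes top_k consecutive copies of every token, so copy j of token i lines up with its j-th selected expert. A small shape-level sketch under those assumptions (the gate output shape here is hypothetical):

    import torch

    top_k, d_model = 2, 4
    inp = torch.randn(3, d_model)                     # 3 tokens
    gate_top_k_idx = torch.randint(0, 8, (3, top_k))  # hypothetical gate output

    inp_rep = inp.repeat_interleave(repeats=top_k, dim=0)  # (3*top_k, d_model)
    gate_flat = gate_top_k_idx.view(-1)                    # (3*top_k,)
    # row top_k*i + j of inp_rep is token i routed to expert gate_top_k_idx[i, j]
    print(inp_rep.shape, gate_flat.shape)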