OpenDAS / FastMoE
"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "c67425b070ec8fc2d6e4757cd74cf4b171d48902"
Commit 0fea2991, authored Jan 29, 2021 by Rick Ho

    fix more bugs to make the layers run in the model

Parent: 6900f1de
Showing 2 changed files with 16 additions and 13 deletions:

    fmoe/fmoe_functions.py   +12 -8
    fmoe/layers.py            +4 -5
fmoe/fmoe_functions.py
@@ -13,10 +13,13 @@ def moe_prepare_forward(gate, num_expert, world_size):
             device=gate.device, dtype=torch.long)
     local_expert_count.index_put_((gate_idx.long(),), gate_count)
-    global_expert_count, = fmoe_cuda.expert_exchange(
-            local_expert_count, num_expert, world_size)
+    if world_size > 1:
+        global_expert_count, = fmoe_cuda.expert_exchange(
+                local_expert_count, num_expert, world_size)
+    else:
+        global_expert_count = local_expert_count
     fwd_expert_count = global_expert_count.view(world_size,
-            num_expert).sum(dim=0).cpu()
+            num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
     return (pos, local_expert_count.cpu(), global_expert_count.cpu(),
             fwd_expert_count.cpu(), fwd_batch_size)
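The main change in this hunk is a single-process fallback: when world_size is 1 there is no other rank to exchange counts with, so global_expert_count is simply the local count, and fwd_expert_count now stays on device until the final return. Below is a rough standalone sketch of the same bookkeeping, not the library's code; the torch.unique origin of gate_idx and gate_count is an assumption about lines above the hunk.

import torch

def count_per_expert(gate, num_expert, world_size=1):
    # gate: 1-D LongTensor of expert indices, one entry per (replicated) token
    local_expert_count = torch.zeros(num_expert * world_size,
                                     device=gate.device, dtype=torch.long)
    # assumed origin of gate_idx / gate_count (not shown in the hunk)
    gate_idx, gate_count = torch.unique(gate, return_counts=True)
    local_expert_count.index_put_((gate_idx.long(),), gate_count)
    # world_size == 1: the patch skips fmoe_cuda.expert_exchange entirely
    global_expert_count = local_expert_count
    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
    fwd_batch_size = int(fwd_expert_count.sum().item())
    return local_expert_count, fwd_expert_count, fwd_batch_size

gate = torch.tensor([0, 2, 2, 1, 0, 2])
print(count_per_expert(gate, num_expert=3))
# (tensor([2, 1, 3]), tensor([2, 1, 3]), 6)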
@@ -35,6 +38,7 @@ class MOEScatter(Function):
             global_input_buf = local_input_buf
         ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
+        ctx.save_for_backward(*variables)
         return global_input_buf

     @staticmethod
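The single added line, ctx.save_for_backward(*variables), is what makes the saved tensors visible to MOEScatter's backward pass. As a purely illustrative toy Function (not FastMoE code), omitting that call leaves ctx.saved_tensors empty and the unpacking in backward fails:

import torch
from torch.autograd import Function

class Scale(Function):
    @staticmethod
    def forward(ctx, inp, factor):
        ctx.save_for_backward(factor)   # omit this and backward() has nothing to unpack
        return inp * factor

    @staticmethod
    def backward(ctx, grad_out):
        factor, = ctx.saved_tensors
        return grad_out * factor, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, torch.tensor(2.0)).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])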
@@ -57,14 +61,14 @@ class MOELinear(Function):
     def forward(ctx, global_input_buf, weight, fwd_expert_count):
         global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
                 fwd_expert_count)
-        variables = (input_buf, weight, fwd_expert_count)
+        variables = (global_input_buf, weight, fwd_expert_count)
         ctx.save_for_backward(*variables)
         return global_output_buf

     @staticmethod
     def backward(ctx, grad_out):
         (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = ome_cuda.backward(
+        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                 grad_out, input_buf, weight, fwd_expert_count)
         return grad_inp_buf, grad_weight, None
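Two fixes land here: forward now saves global_input_buf, the tensor that actually exists in its scope, instead of the undefined input_buf, and backward calls fmoe_cuda rather than the misspelled ome_cuda. For intuition only, here is a plain-PyTorch reference of the grouped per-expert GEMM that fmoe_cuda.forward fuses; the (num_expert, out_features, in_features) weight layout and the assumption that input rows arrive grouped by expert are illustrative, not taken from this commit.

import torch

def expert_linear_reference(global_input_buf, weight, fwd_expert_count):
    # global_input_buf: rows grouped by expert; fwd_expert_count[i] rows belong to expert i
    # weight: assumed layout (num_expert, out_features, in_features)
    outputs, start = [], 0
    for i, n in enumerate(fwd_expert_count.tolist()):
        rows = global_input_buf[start:start + n]
        outputs.append(rows @ weight[i].t())
        start += n
    return torch.cat(outputs, dim=0)

x = torch.randn(5, 8)          # 5 rows, d_in = 8
w = torch.randn(2, 16, 8)      # 2 experts, d_out = 16
counts = torch.tensor([3, 2])  # rows 0-2 -> expert 0, rows 3-4 -> expert 1
print(expert_linear_reference(x, w, counts).shape)  # torch.Size([5, 16])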
@@ -81,7 +85,7 @@ class MOEGather(Function):
             local_output_buf = global_output_buf
         output, = fmoe_cuda.local_scatter(local_output_buf, pos)
-        ctx.moe_args = fwd_batch_size, world_size
+        ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
         ctx.save_for_backward(*variables)
         return output
@@ -89,8 +93,8 @@ class MOEGather(Function):
     @staticmethod
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
-        fwd_batch_size, world_size = ctx.moe_args
-        grad_out_buf = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
+        local_batch_size, fwd_batch_size, world_size = ctx.moe_args
+        grad_out_buf, = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
         if world_size > 1:
             global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                     local_expert_count, global_expert_count,
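Together with the previous hunk, this makes the tuple stashed on ctx.moe_args in MOEGather.forward and the unpacking in backward agree again (three values instead of two), and adds the trailing comma so the single-element result of fmoe_cuda.local_gather is unpacked into a tensor. A minimal, purely illustrative Function showing the stash-sizes-on-ctx pattern; the names here are invented for the example.

import torch
from torch.autograd import Function

class Slice(Function):
    @staticmethod
    def forward(ctx, inp, keep):
        ctx.moe_args = inp.shape[0], keep        # pack plain (non-tensor) sizes as a tuple
        return inp[:keep].clone()

    @staticmethod
    def backward(ctx, grad_out):
        full, keep = ctx.moe_args                # unpack with the same arity
        grad_inp = grad_out.new_zeros(full, *grad_out.shape[1:])
        grad_inp[:keep] = grad_out
        return grad_inp, None

x = torch.randn(5, 2, requires_grad=True)
Slice.apply(x, 3).sum().backward()
print(x.grad.shape)  # torch.Size([5, 2])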
fmoe/layers.py
@@ -49,7 +49,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
     for i, l in enumerate(linears):
         if i:
             x = activation(x)
-        x = l(x)
+        x = l(x, fwd_expert_count)
     x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
             inp.shape[0], world_size)
     return x
@@ -78,16 +78,15 @@ class FMoETransformerMLP(nn.Module):
             dtype=torch.float32))

     def forward(self, inp):
-        # import pdb; pdb.set_trace()
         residual = inp
         if self.pre_lnorm:
             inp = self.layer_norm(inp)
-        inp = inp.view(-1, self.d_model).repeat_interleave(
-                repeats=self.top_k, dim=0)  # (BxLxtop_k) x d_model
         gate_top_k_idx, gate_score = self.gate(inp)
+        # TODO: merge replication into local_scatter
+        inp = inp.view(-1, self.d_model).repeat_interleave(
+                repeats=self.top_k, dim=0)  # (BxLxtop_k) x d_model
         x = _fmoe_full_forward(inp, gate_top_k_idx,
                 [self.htoh4, self.h4toh], self.activation,
                 self.num_expert, self.world_size)
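A quick shape check (illustrative only) of the repeat_interleave replication used in forward(): every row is duplicated top_k times along dim 0, giving the "(BxLxtop_k) x d_model" layout that the in-line comment refers to.

import torch

top_k, d_model = 2, 4
inp = torch.randn(3, d_model)                                    # 3 tokens of width d_model
rep = inp.repeat_interleave(repeats=top_k, dim=0)                # each token repeated top_k times in a row
print(rep.shape)                                                 # torch.Size([6, 4])
print(torch.equal(rep[0], inp[0]), torch.equal(rep[1], inp[0]))  # True True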