OpenDAS / FastMoE

Commit 22e1eb45, authored Feb 01, 2021 by Rick Ho

    complete test for reconstruction

Parent: d2039fc7

Showing 6 changed files with 16 additions and 125 deletions (+16 -125):
  fmoe/__init__.py          +0   -1
  fmoe/fmoe_functions.py    +5   -2
  fmoe/layers.py            +1   -1
  fmoe/moe_function.py      +0   -106
  tests/moe.py              +7   -8
  tests/moe_test.py         +3   -7
fmoe/__init__.py

-from .moe import BruteForceMoE
 from .layers import FMoELinear, FMoENaiveGate, FMoETransformerMLP
fmoe/fmoe_functions.py

@@ -3,8 +3,11 @@ from torch.autograd import Function

 import fmoe_cuda


-def moe_prepare_forward(gate, num_expert, world_size):
-    fmoe_cuda.ensure_nccl(torch.distributed.distributed_c10d._default_pg, gate)
+def moe_prepare_forward(gate, num_expert, world_size, comm=None):
+    if comm is None:
+        comm = torch.distributed.distributed_c10d._default_pg
+    if world_size > 1:
+        fmoe_cuda.ensure_nccl(comm, gate)
     with torch.no_grad():
         _, pos = torch.sort(gate)
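The change above lets callers hand moe_prepare_forward an explicit communicator instead of always using PyTorch's default process group, and it skips the NCCL setup entirely when world_size is 1. The snippet below is a usage sketch of ours, not code from the repository; it assumes torch.distributed has already been initialized with the NCCL backend and that the fmoe CUDA extension is installed.

# Illustrative only: exercising the new comm/world_size handling of
# moe_prepare_forward. Assumes dist.init_process_group("nccl", ...) has
# already run on every rank.
import torch
import torch.distributed as dist
from fmoe.fmoe_functions import moe_prepare_forward

world_size = dist.get_world_size()
num_expert = 4

# One expert index per local token; the older test code cast gates to int.
gate = torch.randint(0, num_expert * world_size, (16,), device="cuda").int()

# comm=None falls back to the default process group; a sub-group created
# with dist.new_group(...) could be passed instead. With world_size == 1
# the ensure_nccl call is skipped entirely.
buffers = moe_prepare_forward(gate, num_expert, world_size, comm=None)
# The returned routing buffers are not unpacked in this sketch.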
fmoe/layers.py

@@ -57,7 +57,7 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):

 class FMoETransformerMLP(nn.Module):
     def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
-            world_size=None, activation=torch.nn.functional.gelu,
+            world_size=1, activation=torch.nn.functional.gelu,
             top_k=2, pre_lnorm=False):
         super(FMoETransformerMLP, self).__init__()
         self.num_expert = num_expert
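With world_size now defaulting to 1, the Transformer MLP layer can be constructed for single-GPU use without any distributed setup. The following is a hedged sketch of such a construction; the sizes and the assumption that forward takes a single input tensor are ours, not part of this commit.

# Single-process sketch (assumed usage): build the MoE MLP with the new
# default world_size=1, so no torch.distributed initialization is required.
import torch
from fmoe.layers import FMoETransformerMLP

mlp = FMoETransformerMLP(num_expert=4, d_model=256, d_hidden=1024,
                         world_size=1, top_k=2).cuda()

x = torch.randn(8, 256, device="cuda")  # 8 token embeddings of width d_model
y = mlp(x)                              # routed through the 4 local experts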
fmoe/moe_function.py  deleted (100644 → 0)

-import torch
-from torch.autograd import Function
-
-import fmoe_cuda
-
-
-class MOELocal(Function):
-    @staticmethod
-    def forward(ctx, inp, gate, weight):
-        _, pos = torch.sort(gate)
-        gate_idx, gate_count = torch.unique(gate, return_counts=True)
-        expert_count = torch.zeros(weight.shape[0], device=weight.device,
-                dtype=torch.long)
-        expert_count.index_put_((gate_idx.long(), ), gate_count)
-        # expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
-        ecc = expert_count.cpu()
-
-        input_buf, = fmoe_cuda.local_gather(inp, pos)
-        output_buf, = fmoe_cuda.forward(input_buf, weight, ecc)
-        output = fmoe_cuda.local_gather(output_buf, pos)
-
-        variables = [input_buf, gate, weight, ecc, pos]
-        ctx.save_for_backward(*variables)
-
-        return output[0]
-
-    @staticmethod
-    def backward(ctx, grad_out):
-        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors
-
-        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
-        grad_inp_buf, grad_weight = fmoe_cuda.backward(
-                grad_out_buf, input_buf, weight, expert_count)
-        grad_inp, = fmoe_cuda.local_gather(grad_inp_buf, pos)
-
-        return grad_inp, None, grad_weight
-
-
-class MOEGlobal(Function):
-    @staticmethod
-    def forward(ctx, inp, gate, weight, world_size):
-        fmoe_cuda.ensure_nccl(
-                torch.distributed.distributed_c10d._default_pg, inp)
-        num_expert = weight.shape[0]
-
-        # local_expert_count, pos = fmoe_cuda.expert_count(gate,
-        #         world_size * num_expert)
-        _, pos = torch.sort(gate)
-        gate_idx, gate_count = torch.unique(gate, return_counts=True)
-        local_expert_count = torch.zeros(weight.shape[0] * world_size,
-                device=weight.device, dtype=torch.long)
-        local_expert_count.index_put_((gate_idx.long(), ), gate_count)
-
-        global_expert_count, = fmoe_cuda.expert_exchange(
-                local_expert_count, num_expert, world_size)
-        fwd_expert_count = global_expert_count.view(world_size,
-                num_expert).sum(dim=0).cpu()
-        fwd_batch_size = int(fwd_expert_count.sum().item())
-
-        local_input_buf, = fmoe_cuda.local_gather(inp, pos)
-
-        local_expert_count = local_expert_count.cpu()
-        global_expert_count = global_expert_count.cpu()
-
-        local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
-                local_input_buf, weight, local_expert_count,
-                global_expert_count, fwd_batch_size, inp.shape[0],
-                world_size)
-        output, = fmoe_cuda.local_scatter(local_output_buf, pos)
-
-        variables = (global_input_buf, gate, weight,
-                local_expert_count, global_expert_count, fwd_expert_count,
-                pos)
-        ctx.moe_args = (num_expert, inp.shape[0], fwd_batch_size, world_size)
-        ctx.save_for_backward(*variables)
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_out):
-        (input_buf, gate, weight, local_expert_count, global_expert_count,
-                fwd_expert_count, pos) = ctx.saved_tensors
-        num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args
-
-        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
-        global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
-                local_expert_count, global_expert_count, fwd_batch_size,
-                world_size)
-
-        grad_inp_buf, grad_weight = fmoe_cuda.backward(
-                global_grad_out_buf, input_buf, weight, fwd_expert_count)
-
-        local_grad_inp_buf, = fmoe_cuda.global_gather(grad_inp_buf,
-                local_expert_count, global_expert_count,
-                local_batch_size, world_size)
-        grad_inp, = fmoe_cuda.local_gather(local_grad_inp_buf, pos)
-
-        return grad_inp, None, grad_weight, None
-
-
-def moe(inp, gate, weight, world_size):
-    if world_size is not None and world_size > 1:
-        return MOEGlobal.apply(inp, gate, weight, world_size)
-    else:
-        return MOELocal.apply(inp, gate, weight)
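The deleted MOELocal/MOEGlobal autograd functions are no longer needed because the test wrapper below switches to _fmoe_full_forward from fmoe.layers. Their expert-counting idiom (sort tokens by expert id, then count with torch.unique and index_put_) is still useful to keep in mind when reading the replacement; the standalone snippet below is our own CPU-only illustration of it, not code from this commit.

# Illustration of the counting trick from the deleted code: group tokens by
# the expert they are routed to and count how many tokens each expert gets.
import torch

num_expert = 4
gate = torch.tensor([2, 0, 2, 3, 0, 2])   # expert id per token

_, pos = torch.sort(gate)                  # token indices grouped by expert
gate_idx, gate_count = torch.unique(gate, return_counts=True)

expert_count = torch.zeros(num_expert, dtype=torch.long)
expert_count.index_put_((gate_idx.long(),), gate_count)

print(expert_count)  # tensor([2, 0, 3, 1]): experts 0-3 get 2, 0, 3, 1 tokens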
fmoe/moe.py → tests/moe.py

@@ -3,28 +3,27 @@ from torch import nn

 import torch
 import torch.nn.functional as F

-from .moe_function import moe
+from fmoe.layers import FMoELinear, _fmoe_full_forward


 class FMoE(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
-            world_size=None):
+            world_size=1):
         super(FMoE, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.world_size = world_size
-        self.weight = nn.Parameter(
-                torch.Tensor(num_expert, out_feat, in_feat))
+        self.linear = FMoELinear(num_expert, in_feat, out_feat)
+        self.weight = self.linear.weight
         self.reset_parameters()

     def reset_parameters(self):
-        for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
-            self.weight.data[i] = linear.weight.data
+        self.linear.reset_parameters()

     def forward(self, inp, gate):
-        return moe(inp, gate.int(), self.weight, self.world_size)
+        return _fmoe_full_forward(inp, gate, [self.linear], None,
+                self.num_expert, self.world_size)


 class BruteForceMoE(nn.Module):
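The FMoE test wrapper now reuses FMoELinear and _fmoe_full_forward from the package, while BruteForceMoE remains the naive reference it is checked against. As a rough idea of what such a reference computation looks like (our own sketch, not the file's code), each token is simply multiplied by the weight matrix of the expert its gate value selects:

# Naive per-token MoE reference (illustrative, CPU-only): weight has shape
# (num_expert, out_feat, in_feat), gate holds one expert id per token.
import torch

def brute_force_moe(inp, gate, weight):
    out = torch.empty(inp.shape[0], weight.shape[1], dtype=inp.dtype)
    for i in range(inp.shape[0]):
        out[i] = weight[gate[i]] @ inp[i]  # apply the selected expert's weight
    return out

inp = torch.randn(6, 8)
gate = torch.randint(0, 4, (6,))
weight = torch.randn(4, 16, 8)
print(brute_force_moe(inp, gate, weight).shape)  # torch.Size([6, 16])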
tests/moe_test.py

-from fmoe import FMoE as MOELayer
-from fmoe import BruteForceMoE as MOELayer_raw
+from moe import FMoE as MOELayer
+from moe import BruteForceMoE as MOELayer_raw
 import torch
 from torch import nn
 import time

@@ -82,7 +82,6 @@ def test_module(moe, linear, inp, gate):
     moe.zero_grad()
     x = (linear(inp))
     output = moe(x, gate)
-    # print('ooutput', torch.distributed.get_rank(), output)
     y = output.mean()
     y.backward()
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad

@@ -102,10 +101,7 @@ def test():
     linear = nn.Linear(in_feat, in_feat).cuda()
-    if world_size > 1:
-        moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
-    else:
-        moe = MOELayer(num_expert, in_feat, out_feat).cuda()
+    moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
     moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
     if world_size == 1:
         moe_raw.weight.data = moe.weight.data.clone()