FastMoE (OpenDAS) : Commit 91a5e794
Authored Apr 01, 2022 by Rick Ho
Parent: 794dd0e6

    faster policies

Showing 7 changed files with 131 additions and 14 deletions (+131 -14).
Changed files:

    fmoe/fastermoe/config.py          +13   -0
    fmoe/fastermoe/schedule.py        +12   -2
    fmoe/fastermoe/shadow_policy.py   +73   -0
    fmoe/functions.py                  +9   -0
    fmoe/layers.py                     +5   -1
    tests/test_ddp.py                  +1   -1
    tests/test_faster_shadow.py       +18  -10
fmoe/fastermoe/config.py (new file, mode 100644)

import os


def float_from_env(key, default=-1):
    if key in os.environ:
        return float(os.environ[key])
    return default


def switch_from_env(key, default=False):
    if key in os.environ:
        return os.environ[key] in ['1', 'ON']
    return default
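These two helpers back the FMOE_FASTER_* environment switches used throughout the rest of the commit. A minimal usage sketch; the override values below are only illustrative:

import os
from fmoe.fastermoe.config import float_from_env, switch_from_env

# Numeric knobs fall back to their defaults unless the variable is set.
os.environ['FMOE_FASTER_GLBPLC_ALPHA'] = '4'                  # illustrative override
print(float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2))          # 4.0, parsed from the env
print(float_from_env('FMOE_FASTER_GLBPLC_NETBW', 6.25e9))     # 6.25e9, the default

# Boolean switches accept '1' or 'ON'; anything else (or unset) means False.
os.environ['FMOE_FASTER_SHADOW_ENABLE'] = 'ON'
print(switch_from_env('FMOE_FASTER_SHADOW_ENABLE'))           # True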
fmoe/fastermoe/schedule.py

@@ -9,6 +9,8 @@ from fmoe.functions import _local_scatter, _local_gather
 import fmoe_cuda as fmoe_native
 from fmoe.fastermoe import expert_utils
+from .shadow_policy import get_shadow_policy
+

 class MoEForward(Function):
     @staticmethod
@@ -31,6 +33,7 @@ class MoEForward(Function):
             x = x.data
             with torch.enable_grad():
                 x.requires_grad = True
+                # To skip torch autograd's version check.
                 with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
                     y0 = expert_fn(x, [x.shape[0]])
             ctx.gibs[idx] = x
@@ -101,6 +104,9 @@ class MoEForward(Function):
         return (None, None, grad_in, None, None, None, None, None, None, None, None)


+policy_fn = None
+
+
 def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
     # TODO: Using multiple tensors as input is to be supported.
     assert(isinstance(inp, torch.Tensor))
@@ -114,9 +120,13 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp
         fwd_batch_size,
     ) = prepare_forward(gate, n_expert, world_size)
-    # TODO: Expert shadowing is to be supported. Currently using all 0s
+    global policy_fn
+    if policy_fn is None:
+        policy_fn = get_shadow_policy(d_model=inp.shape[-1])
     if stored_models is None:
-        stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)
+        stored_models = policy_fn(local_expert_count, global_expert_count,
+                n_expert, world_size)
     topk = 1
     if len(gate.shape) == 2:
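Since policy_fn is a lazily initialized module-level global, a custom shadowing policy can presumably be installed by assigning to it before the first forward pass. A rough sketch under that assumption; my_policy is hypothetical and only has to follow the same contract as the policies in shadow_policy.py:

import torch
from fmoe.fastermoe import schedule

def my_policy(local_expert_count, global_expert_count, n_expert, world_size):
    # Hypothetical policy that never shadows any expert: one boolean flag per
    # (worker, expert) pair, as the built-in policies return.
    return torch.zeros(n_expert * world_size, dtype=torch.bool)

# Installing it before the first call skips the lazy get_shadow_policy() setup.
schedule.policy_fn = my_policy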
fmoe/fastermoe/shadow_policy.py (new file, mode 100644)

import os

import torch
import torch.distributed as dist

from .config import float_from_env, switch_from_env
from fmoe.functions import get_moe_group


def global_policy(local_expert_count, _gec, num_expert, world_size):
    r"""
    This is the policy for two-layer MLPs, using the formula in the PPoPP paper.
    A few parameters are used in this policy.
    * `d_model`: feature length of the MLP input and output.
    * `alpha`: the ratio of the MLP's hidden size to `d_model`.
    * `bw_net`: bandwidth of the network (GBps).
    * `bw_mm`: computation throughput of performing GeMM (FLOPs).
    """
    bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8)
    bw_mm = float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12)
    alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2)
    d_model = float_from_env('FMOE_FASTER_GLBPLC_DMODEL', 2048)

    moe_group = get_moe_group()
    local_expert_count = local_expert_count.cuda()
    agecs = [torch.empty_like(local_expert_count) for _ in range(world_size)]
    dist.all_gather(agecs, local_expert_count, group=moe_group)
    all_global_expert_count = torch.stack(agecs)

    # TODO: data type other than float
    data_size = 4

    fwd_expert_counts = all_global_expert_count.sum(1).cpu()
    B_ws, indices = fwd_expert_counts.flatten().sort(0, descending=True)

    alphaH2 = alpha * (d_model ** 2)
    B_w = B_ws[0]

    comm = float('+inf')
    send_feature_time = d_model * data_size / bw_net
    send_model_time = 2 * alphaH2 * data_size / bw_net
    comp_time = 4 * alphaH2 / bw_mm
    lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w

    res = torch.zeros(world_size * num_expert, dtype=torch.bool)
    shadow_time = 0

    for i, index in enumerate(indices):
        if i + 1 == indices.numel():
            break
        B_k = B_ws[i + 1]
        shadow_time += send_model_time
        lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + shadow_time

        if lat_new < lat_base:
            lat_base = lat_new
            res[index] = True
        else:
            break

    return res


def no_shadow_policy(_lec, _gec, num_expert, world_size):
    res = torch.zeros(world_size * num_expert, dtype=bool)
    return res
def get_shadow_policy(d_model=None):
    if d_model is not None and 'FMOE_FASTER_GLBPLC_DMODEL' not in os.environ:
        os.environ['FMOE_FASTER_GLBPLC_DMODEL'] = str(d_model)
    if not switch_from_env('FMOE_FASTER_SHADOW_ENABLE'):
        return no_shadow_policy
    return global_policy
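The loop in global_policy is a greedy cost model: it keeps shadowing the most heavily loaded experts as long as the saved computation and token traffic outweigh the cost of broadcasting the expert weights. A back-of-the-envelope check of the quantities it compares, using the default parameters; the load values B_w and B_k below are made up for illustration:

# Illustrative evaluation of global_policy's cost model with its defaults.
d_model, alpha, data_size = 2048.0, 2.0, 4
bw_net = 50 * 1e9 / 8            # ~6.25 GB/s network bandwidth
bw_mm = 11.5e12                  # 11.5 TFLOPS GeMM throughput

alphaH2 = alpha * (d_model ** 2)
send_feature_time = d_model * data_size / bw_net     # per-token transfer time
send_model_time = 2 * alphaH2 * data_size / bw_net   # per-expert broadcast time
comp_time = 4 * alphaH2 / bw_mm                      # per-token GeMM time

B_w, B_k = 8192, 2048   # hypothetical largest / next-largest expert loads
lat_base = 3 * comp_time * B_w + 4 * send_feature_time * B_w
lat_new = 3 * comp_time * B_k + 4 * send_feature_time * B_k + send_model_time

# Shadowing the busiest expert pays off when lat_new < lat_base.
print(lat_base, lat_new, lat_new < lat_base)   # here: ~0.115 s vs ~0.039 s -> True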
fmoe/functions.py

@@ -10,12 +10,21 @@ import fmoe_cuda
 from .utils import get_torch_default_comm

+_moe_group = None
+

 def ensure_comm(t, comm):
     if comm is None:
         comm = get_torch_default_comm()
+    global _moe_group
+    _moe_group = comm
     fmoe_cuda.ensure_nccl(comm, t)


+def get_moe_group():
+    return _moe_group
+
+
 def count_by_gate(gate, num_expert, world_size, require_pos=True):
     with torch.no_grad():
         local_expert_count = torch.zeros(
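ensure_comm now records the communicator it was handed (or the default one), and get_moe_group returns it to later callers such as global_policy's all_gather. A minimal sketch of that ordering, assuming torch.distributed and CUDA are already initialized:

import torch
import torch.distributed as dist
from fmoe.functions import ensure_comm, get_moe_group

x = torch.empty(16, device='cuda')
ensure_comm(x, None)                 # None falls back to the default torch comm
group = get_moe_group()              # the communicator recorded above

counts = torch.zeros(4, dtype=torch.long, device='cuda')
gathered = [torch.empty_like(counts) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, counts, group=group)   # same pattern as global_policy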
fmoe/layers.py

@@ -11,6 +11,8 @@ from .functions import MOEScatter, MOEGather
 from .functions import AllGather, Slice
 from .gates import NaiveGate
+from .fastermoe.config import switch_from_env
+

 def mark_module_parallel_comm(module, comm):
     r"""
@@ -76,7 +78,9 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size, *
     return outp


-if os.environ.get('FMOE_FASTER_SCHEDULE_ENABLE', '0') in ['1', 'ON']:
+fmoe_faster_schedule = False
+if switch_from_env('FMOE_FASTER_SCHEDULE_ENABLE', False):
+    fmoe_faster_schedule = True
     from .fastermoe.schedule import _fmoe_general_global_forward
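The switch is evaluated at import time of fmoe.layers, so the schedule has to be chosen through the environment before fmoe is imported. A sketch, assuming a working fmoe installation; the model construction itself is omitted:

import os

# Opt in to the faster schedule and to expert shadowing; both are read through
# switch_from_env, so '1' works as well as 'ON'.
os.environ['FMOE_FASTER_SCHEDULE_ENABLE'] = 'ON'
os.environ['FMOE_FASTER_SHADOW_ENABLE'] = 'ON'

import fmoe  # noqa: E402  (must come after the environment is set)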
tests/test_ddp.py

@@ -3,6 +3,7 @@ import random
 import os
 import sys
 from typing import Dict
+import random

 import pytest
 import torch
@@ -19,7 +20,6 @@ def _ensure_initialized():
     os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
     os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
-    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
     if not dist.is_initialized():
         dist.init_process_group(backend="nccl")
tests/test_faster_shadow.py

@@ -17,25 +17,28 @@ from fmoe.layers import _fmoe_general_global_forward as naive_fwd
 @pytest.mark.parametrize("n_process", [8])
 @pytest.mark.parametrize("d_model", [1024])
-@pytest.mark.parametrize("batch_size", [16])
+@pytest.mark.parametrize("batch_size", [16, 512])
 @pytest.mark.parametrize("n_expert", [1])
 @pytest.mark.parametrize("group_sz", [1, 2, 4])
-def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz):
+@pytest.mark.parametrize("pass_stored", [False, True])
+def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz, pass_stored):
     _run_distributed('_test_faster_shadow',
             n_process,
             {
                 'd_model': d_model,
                 'batch_size': batch_size,
-                'n_expert': n_expert
+                'n_expert': n_expert,
+                'pass_stored': pass_stored
             },
             script=__file__,
             env=dict(
-                FMOE_FASTER_GROUP_SIZE=str(group_sz)
+                FMOE_FASTER_GROUP_SIZE=str(group_sz),
+                FMOE_FASTER_SHADOW_ENABLE='ON'
             )
     )


-def _test_faster_shadow(d_model, batch_size, n_expert):
+def _test_faster_shadow(d_model, batch_size, n_expert, pass_stored):
     _ensure_initialized()
     rank = dist.get_rank()
     world_size = dist.get_world_size()
@@ -58,6 +61,7 @@ def _test_faster_shadow(d_model, batch_size, n_expert):
         y = m2(x)
         return y

+    if pass_stored:
         stored_models = torch.randint(0, 2, (world_size,)).bool().cuda()
         dist.broadcast(stored_models, 0)
         stored_models = stored_models.cpu()
@@ -66,7 +70,11 @@ def _test_faster_shadow(d_model, batch_size, n_expert):
     # print('stored models {}'.format(stored_models))
     ensure_comm(x1, None)

-    y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1, stored_models=stored_models)
+    if pass_stored:
+        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1, stored_models=stored_models)
+    else:
+        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1)
     y1.sum().backward()

     y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2)
@@ -82,4 +90,4 @@ if __name__ == '__main__':
         locals()[sys.argv[1]](**args)
     else:
         # test_faster_shadow(8, 16, 16, 1, 2)
-        _test_faster_shadow(4, 2, 1)
+        _test_faster_shadow(1024, 16, 1, True)