OpenDAS / FastMoE

Commit b5b72d41, authored Mar 30, 2022 by Rick Ho
Parent: 771dc62d

forward tested
Showing 5 changed files with 95 additions and 6 deletions.
cuda/fastermoe/smart_schedule.h   +1  -1
cuda/fmoe_cuda.cpp                +2  -0
fmoe/fastermoe/expert_utils.py    +4  -1
fmoe/fastermoe/schedule.py        +5  -4
tests/test_faster_shadow.py       +83 -0
cuda/fastermoe/smart_schedule.h

```diff
@@ -168,7 +168,7 @@ void fmoe_cuda_fused_forward_impl(
             cudaEventRecord(evt_get, torch_stream);
             cudaStreamWaitEvent(smgr->stream(1), evt_get);
         }
-        NCCL_SAFE_CALL(ncclBcast(params[si].data_ptr<void>(),
+        NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
                     expert_size * sizeof(scalar_t), ncclChar,
                     i / num_expert, smgr->ncclcomm, smgr->stream(0)));
         cudaEventCreate(evt_shadow + si);
```
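The one-line fix replaces the untyped `data_ptr<void>()` call with a typed pointer cast to `void*`, so the NCCL broadcast of a shadowed expert's flattened parameter buffer compiles and sends `expert_size * sizeof(scalar_t)` bytes from the owning worker to everyone else. A rough PyTorch-level sketch of the same idea, with an illustrative helper name and owner rank that are not part of the FastMoE API:

```python
import torch
import torch.distributed as dist

def broadcast_expert_params(expert: torch.nn.Module, owner_rank: int) -> torch.Tensor:
    # Flatten the expert's parameters into one contiguous buffer,
    # mirroring get_expert_params() in expert_utils.py.
    numel = sum(p.numel() for p in expert.parameters())
    flat = torch.empty(numel, device='cuda')
    offset = 0
    for p in expert.parameters():
        flat[offset:offset + p.numel()].copy_(p.data.flatten())
        offset += p.numel()
    # One collective moves the whole expert, analogous to the ncclBcast
    # call in smart_schedule.h.
    dist.broadcast(flat, src=owner_rank)
    return flat
```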
cuda/fmoe_cuda.cpp

```diff
@@ -80,6 +80,8 @@ torch::Tensor _smart_sch_backward(
         long expert_size,
         long n_workers,
         py::function backward_fn,
+        py::function stash_fn,
+        py::function pop_fn,
         py::function collect_fn,
         py::function set_grad_fn);
```
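The two new arguments let the fused backward call back into Python to swap shadowed expert parameters in and out. Judging from how schedule.py builds them (see the diff below), the callbacks are plain closures over the worker's local expert module: stash_fn receives a flat parameter buffer and a shadow index, pop_fn takes no arguments and restores the stashed parameters. A minimal sketch with a stand-in expert (the `torch.nn.Linear` here is illustrative only):

```python
import torch
from fmoe.fastermoe import expert_utils

experts = torch.nn.Linear(4, 4)  # stand-in for the worker's local expert

# Callback shapes matching what schedule.py hands to the native code.
stash_fn = lambda params, idx: expert_utils.stash_expert_params(experts, params)
pop_fn = lambda: expert_utils.pop_expert_params(experts)
```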
fmoe/fastermoe/expert_utils.py

```diff
@@ -6,11 +6,12 @@ def get_expert_param_size(e):
 def get_expert_params(e, out):
+    print('gep to {}'.format(out))
     offset = 0
     for n, p in e.named_parameters():
         seg = out[offset:offset + p.numel()]
         offset += p.numel()
-        seg.copy_(p)
+        seg.copy_(p.data.flatten())


 def stash_expert_params(e, params):
@@ -27,6 +28,8 @@ def stash_expert_params(e, params):
 def pop_expert_params(e):
+    if not hasattr(e, 'expert_param_stash'):
+        return
     for n, p in e.named_parameters():
         with torch.no_grad():
             p.copy_(e.expert_param_stash[n])
```
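Two small robustness fixes here: get_expert_params now copies a flattened view of each parameter into the 1-D output buffer (copying a 2-D weight into a flat slice would otherwise raise a shape error), and pop_expert_params becomes a no-op when nothing has been stashed. A minimal round-trip sketch of how these helpers fit together, assuming get_expert_param_size returns the total flattened element count and that stash_expert_params accepts the same flat buffer while recording the current parameters on the module:

```python
import torch
from fmoe.fastermoe import expert_utils

expert = torch.nn.Linear(8, 8)  # illustrative expert module

# Serialize the expert into a flat buffer, as the shadowing path does.
buf = torch.empty(expert_utils.get_expert_param_size(expert))
expert_utils.get_expert_params(expert, buf)

# Stash the live parameters, then restore them again.
expert_utils.stash_expert_params(expert, buf)
expert_utils.pop_expert_params(expert)
```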
fmoe/fastermoe/schedule.py

```diff
@@ -7,7 +7,7 @@ from torch.autograd.function import Function
 from fmoe.functions import prepare_forward, ensure_comm
 from fmoe.functions import _local_scatter, _local_gather
 import fmoe_cuda as fmoe_native
-import expert_utils
+from fmoe.fastermoe import expert_utils


 class MoEForward(Function):
@@ -47,7 +47,7 @@ class MoEForward(Function):
         pop_fn = lambda: expert_utils.pop_expert_params(experts)
         ctx.shadows = [None] * world_size
         def stash_fn(params, idx):
-            expert_utils.stash_expert_params(experts, p)
+            expert_utils.stash_expert_params(experts, params)
             ctx.shadows[idx] = params

         local_output_buf, gib = fmoe_native.smart_sch_forward(
@@ -99,7 +99,7 @@ class MoEForward(Function):
         return (None, None, grad_in, None, None, None, None, None, None, None, None)


-def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None):
+def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
     # TODO: Using multiple tensors as input is to be supported.
     assert(isinstance(inp, torch.Tensor))
     # TODO: Support many experts on each process
@@ -113,7 +113,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp
     ) = prepare_forward(gate, n_expert, world_size)
     # TODO: Expert shadowing is to be supported. Currently using all 0s
-    stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)
+    if stored_models is None:
+        stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)
     topk = 1
     if len(gate.shape) == 2:
```
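These hunks fix the relative import, fix the typo that passed an undefined `p` into `stash_expert_params`, and turn `stored_models` into an optional keyword argument so callers can supply an explicit shadow mask; omitting it preserves the previous all-False behaviour. A standalone sketch of that default-handling pattern (the function name is illustrative, not FastMoE's):

```python
import torch

def resolve_shadow_mask(n_expert, world_size, stored_models=None):
    # No mask supplied: treat every expert as un-shadowed, exactly as the
    # previous hard-coded torch.zeros(...) line did.
    if stored_models is None:
        stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)
    return stored_models
```

The new test below exercises the same keyword end to end by passing a randomly generated per-worker mask.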
tests/test_faster_shadow.py (new file, 0 → 100644)

```python
import pytest
import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.functions import ensure_comm

from test_ddp import _ensure_initialized, _run_distributed
from test_numerical import _assert_numerical

from fmoe.fastermoe.schedule import _fmoe_general_global_forward as smart_fwd
from fmoe.layers import _fmoe_general_global_forward as naive_fwd


@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1])
@pytest.mark.parametrize("group_sz", [1, 2, 4])
def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz):
    _run_distributed('_test_faster_shadow',
            n_process,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert
            },
            script=__file__,
            env=dict(
                FMOE_FASTER_GROUP_SIZE=str(group_sz)
            ))


def _test_faster_shadow(d_model, batch_size, n_expert):
    _ensure_initialized()
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    x1 = torch.rand(batch_size, d_model).cuda()
    x1.requires_grad = True
    x2 = x1.data.clone()
    x2.requires_grad = True
    topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda()

    m1 = torch.nn.Linear(d_model, d_model).cuda()
    m2 = torch.nn.Linear(d_model, d_model).cuda()
    with torch.no_grad():
        m2.weight.copy_(m1.weight)
        m2.bias.copy_(m1.bias)

    def ef1(x, fec):
        y = m1(x)
        return y

    def ef2(x, fec):
        y = m2(x)
        return y

    stored_models = torch.randint(0, 2, (world_size,)).bool().cuda()
    dist.broadcast(stored_models, 0)
    stored_models = stored_models.cpu()

    ensure_comm(x1, None)
    y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size,
            experts=m1, stored_models=stored_models)
    # y1.sum().backward()

    y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2)
    # y2.sum().backward()

    _assert_numerical(['out'], [y1], [y2], rank)
    # _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
    #         [y1, x1.grad, m1.bias.grad, m1.weight.grad],
    #         [y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        # test_faster_shadow(8, 16, 16, 1, 2)
        _test_faster_shadow(4, 2, 1)
```