OpenDAS / FastMoE

Commit 14878015, authored Mar 21, 2023 by zms1999
Parent: 698a12ae

support n_expert > 1 for FasterMoE smart scheduling and expert shadowing

Showing 7 changed files with 94 additions and 60 deletions:
cuda/fastermoe/smart_schedule.cpp    +1   -1
cuda/fastermoe/smart_schedule.h      +14  -14
fmoe/fastermoe/expert_utils.py       +18  -6
fmoe/fastermoe/schedule.py           +30  -29
fmoe/layers.py                       +12  -4
fmoe/megatron/layers.py              +13  -2
fmoe/transformer.py                  +6   -4
cuda/fastermoe/smart_schedule.cpp
@@ -104,7 +104,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
         if (stored_models_[i]) {
             torch::Tensor t = input_buf.new_empty({expert_size});
             if (i / num_expert == rank) {
-                get_param_fn(t);
+                get_param_fn(t, i % num_expert);
             }
             params.push_back(t);
         }
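Throughout these loops the global expert id `i` enumerates every expert on every worker, so with more than one expert per worker the code now splits it into an owning rank (`i / num_expert`) and a local expert index (`i % num_expert`), and the latter is what gets passed to `get_param_fn`. A small sketch of that decomposition, with illustrative sizes:

```python
# Sketch of the rank-major expert numbering used above (sizes are illustrative).
num_expert = 2   # experts per worker
world_size = 4   # number of workers

for i in range(world_size * num_expert):
    owning_rank = i // num_expert    # "i / num_expert" in the C++ code
    local_idx = i % num_expert       # "i % num_expert", the index passed to get_param_fn
    print(f"global expert {i}: rank {owning_rank}, local expert {local_idx}")
```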
cuda/fastermoe/smart_schedule.h
@@ -83,7 +83,7 @@ void computePtrs(long num_expert, long rank, long world_size,
 template<typename scalar_t>
 void computeFn(py::function fn, c10::Device device,
         scalar_t* inp_buf, scalar_t* out_buf,
-        long idx, long offset, long micro_batch_size, long d_model,
+        long expert_idx, long store_idx, long offset, long micro_batch_size, long d_model,
         CudaStreamManager* smgr) {
     if (micro_batch_size == 0) {
         return;
@@ -97,7 +97,7 @@ void computeFn(py::function fn, c10::Device device,
     auto oup = torch::from_blob(out_buf + offset * d_model,
             {micro_batch_size, d_model}, options);
     smgr->use_default = true;
-    fn(inp, oup, idx);
+    fn(inp, oup, expert_idx, store_idx);
     smgr->use_default = false;
 }
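`computeFn` now forwards two indices to the Python callback: which local expert to run (`expert_idx`) and which slot to stash the activations in (`store_idx`). A rough Python-side sketch of a callback with that shape (toy modules, not the real `_expert_forward` from `fmoe/fastermoe/schedule.py`):

```python
import torch
import torch.nn as nn

# Toy per-worker experts; FastMoE supplies its own expert modules.
d_model = 16
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(2)])
saved_inputs, saved_outputs = {}, {}

def expert_forward(x, y, expert_idx, store_idx):
    # expert_idx picks the local expert, store_idx picks the stash slot
    out = experts[expert_idx](x)
    saved_inputs[store_idx] = x
    saved_outputs[store_idx] = out
    y.copy_(out)  # write into the preallocated output buffer

x = torch.randn(4, d_model)
y = torch.empty(4, d_model)
expert_forward(x, y, expert_idx=1, store_idx=3)
```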
@@ -174,7 +174,7 @@ void fmoe_cuda_fused_forward_impl(
             if (i / num_expert == rank) {
                 cudaEventCreate(&evt_get);
                 cudaEventRecord(evt_get, torch_stream);
-                FMOE_SWE(smgr->stream(1), evt_get);
+                FMOE_SWE(smgr->stream(0), evt_get);
                 cudaEventDestroy(evt_get);
             }
             NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
@@ -196,7 +196,7 @@ void fmoe_cuda_fused_forward_impl(
                     (from_base + pipeline_gran)] - offset;
             computeFn(forward_fn, device,
                     global_input_buf, global_output_buf,
-                    step, offset, micro_batch_size, d_model, smgr);
+                    (long)ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], torch_stream);
     }
@@ -204,17 +204,17 @@ void fmoe_cuda_fused_forward_impl(
     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
-            stash_fn(params[si], si);
             FMOE_SWE(torch_stream, evt_shadow[si]);
+            stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
             computeFn(forward_fn, device,
                     input_buf, output_buf,
-                    n_groups + si, offset, micro_batch_size, d_model, smgr);
+                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
             ++si;
         }
     }
-    pop_fn();
+    pop_fn(0);
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
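The stash-slot numbering now accounts for several experts per worker: the pipelined steps use slots `step * num_expert + ei`, and shadowed experts are appended after all of them at `n_groups * num_expert + si`, always with `expert_idx = 0` since the shadowed copy is placed first. A small bookkeeping sketch of that layout (sizes are illustrative):

```python
# Sketch of the store_idx layout implied by the computeFn calls above.
num_expert = 2   # local experts per worker (illustrative)
n_groups = 3     # pipeline groups (illustrative)
n_shadow = 2     # shadowed experts this iteration (illustrative)

def regular_slot(step, ei):
    return step * num_expert + ei          # slots 0 .. n_groups * num_expert - 1

def shadow_slot(si):
    return n_groups * num_expert + si      # appended after all regular slots

print([regular_slot(s, e) for s in range(n_groups) for e in range(num_expert)])
print([shadow_slot(s) for s in range(n_shadow)])
```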
@@ -319,13 +319,13 @@ void fmoe_cuda_fused_backward_impl(
     cudaEvent_t* evt_reduce = new cudaEvent_t[num_expert];
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
-            stash_fn(si);
+            stash_fn(si, 0);
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
             computeFn(backward_fn, device, grad_out, grad_in,
-                    n_groups + si, offset, micro_batch_size, d_model, smgr);
-            collect_fn(si, i / num_expert);
+                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
+            collect_fn(si, i / num_expert, 0);
             if (i / num_expert == rank) {
                 cudaEventCreate(evt_reduce + i % num_expert);
                 cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
@@ -333,11 +333,11 @@ void fmoe_cuda_fused_backward_impl(
             ++si;
         }
     }
-    pop_fn();
+    pop_fn(0);
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(1), input_ready[step]);
+        FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -346,7 +346,7 @@ void fmoe_cuda_fused_backward_impl(
             computeFn(backward_fn, device,
                     global_grad_out, global_grad_in,
-                    step, offset, micro_batch_size, d_model, smgr);
+                    (long)ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], torch_stream);
     }
@@ -356,7 +356,7 @@ void fmoe_cuda_fused_backward_impl(
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
-                set_grad_fn(si);
+                set_grad_fn(si, i % num_expert);
             }
             ++si;
         }
fmoe/fastermoe/expert_utils.py
 import torch


-def get_expert_param_size(e):
+def get_expert_param_size(e, idx):
+    e = e[idx]
     return sum(map(lambda x: x.numel(), e.parameters()))


-def get_expert_params(e, out):
+def get_expert_params(e, out, idx):
+    e = e[idx]
     offset = 0
     for n, p in e.named_parameters():
         seg = out[offset:offset + p.numel()]
@@ -13,20 +15,25 @@ def get_expert_params(e, out):
         seg.copy_(p.data.flatten())


-def stash_expert_params(e, params):
+def stash_expert_params(e, params, idx):
+    e = e[idx]
     if not hasattr(e, 'expert_param_stash'):
         setattr(e, 'expert_param_stash', dict())
         setattr(e, 'expert_grad_stash', dict())
     offset = 0
     for n, p in e.named_parameters():
         if n not in e.expert_param_stash:
             e.expert_param_stash[n] = p.data.clone()
             e.expert_grad_stash[n] = p.grad.clone() if p.grad is not None else None
         with torch.no_grad():
             seg = params[offset:offset + p.numel()]
             offset += p.numel()
             p.copy_(seg.reshape(p.shape))
         p.grad = None


-def pop_expert_params(e):
+def pop_expert_params(e, idx):
+    e = e[idx]
     if not hasattr(e, 'expert_param_stash'):
         return
     if not e.expert_param_stash:
@@ -34,10 +41,14 @@ def pop_expert_params(e):
     for n, p in e.named_parameters():
         with torch.no_grad():
             p.copy_(e.expert_param_stash[n])
             if e.expert_grad_stash[n] is not None:
                 p.grad = e.expert_grad_stash[n].clone()
     e.expert_param_stash.clear()
     e.expert_grad_stash.clear()


-def collect_expert_grads(e, grads):
+def collect_expert_grads(e, grads, idx):
+    e = e[idx]
     offset = 0
     for _, p in e.named_parameters():
         seg = grads[offset:offset + p.numel()]
@@ -49,7 +60,8 @@ def collect_expert_grads(e, grads):
             seg.zero_()


-def set_grads(e, grads):
+def set_grads(e, grads, idx):
+    e = e[idx]
     offset = 0
     for n, p in e.named_parameters():
         seg = grads[offset:offset + p.numel()]
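Every helper in this module now receives the expert container plus an index and immediately selects `e = e[idx]`, so `e` is expected to be an indexable collection with one module per local expert rather than a single fused expert. A minimal stash/restore round trip using these helpers (the expert list and sizes below are illustrative):

```python
import torch
import torch.nn as nn
from fmoe.fastermoe import expert_utils

# Illustrative per-worker experts; FastMoE builds these from its own expert classes.
d_model = 8
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(2)])

idx = 1                                                   # local expert to shadow
size = expert_utils.get_expert_param_size(experts, idx)   # flattened parameter count
flat = torch.empty(size)
expert_utils.get_expert_params(experts, flat, idx)        # copy parameters into the flat buffer

expert_utils.stash_expert_params(experts, flat, idx)      # back up and overwrite expert idx
# ... run the shadowed expert here ...
expert_utils.pop_expert_params(experts, idx)              # restore the original parameters
```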
fmoe/fastermoe/schedule.py
@@ -23,12 +23,13 @@ class MoEForward(Function):
             local_expert_count, global_expert_count,
             stored_models, fwd_batch_size, out_batch_size,
+            num_expert,
             world_size):
         local_input_buf = _local_scatter(inp, pos_s)

-        ctx.gibs = [None] * (world_size * 2)
-        ctx.gobs = [None] * (world_size * 2)
-        def _expert_forward(x, y, idx):
+        ctx.gibs = [None] * (world_size * num_expert * 2)
+        ctx.gobs = [None] * (world_size * num_expert * 2)
+        def _expert_forward(x, y, expert_idx, store_idx):
             nothing = lambda a: a
             x = x.data
             with torch.enable_grad():
@@ -40,22 +41,24 @@ class MoEForward(Function):
                 except Exception as e:
                     # Ignore the error and fall back for compatibility to older
                     # versions of PyTorch
-                    y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
-            ctx.gibs[idx] = x
-            ctx.gobs[idx] = y0
+                    y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
+            ctx.gibs[store_idx] = x
+            ctx.gobs[store_idx] = y0
             y.copy_(y0)

         ctx.experts = experts
         if stored_models.any():
-            ctx.expert_size = expert_utils.get_expert_param_size(experts)
+            ctx.expert_size = expert_utils.get_expert_param_size(experts, 0)
+            for i in range(num_expert):
+                assert ctx.expert_size == expert_utils.get_expert_param_size(experts, i), "report bug"
         else:
             ctx.expert_size = 0
-        get_param_fn = lambda out: expert_utils.get_expert_params(experts, out)
-        pop_fn = lambda: expert_utils.pop_expert_params(experts)
-        ctx.shadows = [None] * world_size
-        def stash_fn(params, idx):
-            expert_utils.stash_expert_params(experts, params)
-            ctx.shadows[idx] = params
+        get_param_fn = lambda out, idx: expert_utils.get_expert_params(experts, out, idx)
+        pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
+        ctx.shadows = [None] * world_size * num_expert
+        def stash_fn(params, store_idx, expert_idx):
+            expert_utils.stash_expert_params(experts, params, expert_idx)
+            ctx.shadows[store_idx] = params

         local_output_buf, gib = fmoe_native.smart_sch_forward(
                 local_input_buf,
@@ -71,7 +74,7 @@ class MoEForward(Function):
         variables = (pos_s, pos_g, local_expert_count, global_expert_count,
                 stored_models, gib, local_input_buf)

-        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
+        ctx.moe_args = fwd_batch_size, inp.shape[0], num_expert, world_size
         ctx.save_for_backward(*variables)

         return out
@@ -80,23 +83,23 @@ class MoEForward(Function):
     def backward(ctx, grad_out):
         (pos_s, pos_g, local_expert_count, global_expert_count,
                 stored_models, _1, _2) = ctx.saved_tensors
-        (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
+        (fwd_batch_size, inp_batch_size, num_expert, world_size) = ctx.moe_args

-        def _expert_backward(grad_y, grad_x, idx):
-            y = ctx.gobs[idx]
-            x = ctx.gibs[idx]
+        def _expert_backward(grad_y, grad_x, expert_idx, store_idx):
+            y = ctx.gobs[store_idx]
+            x = ctx.gibs[store_idx]
             torch.autograd.backward([y], [grad_y])
             grad_x.copy_(x.grad)

         experts = ctx.experts
-        def stash_fn(idx):
-            expert_utils.stash_expert_params(experts, ctx.shadows[idx])
-        pop_fn = lambda: expert_utils.pop_expert_params(experts)
-        def collect_fn(idx, root):
-            grad = ctx.shadows[idx]
-            expert_utils.collect_expert_grads(experts, grad)
+        def stash_fn(store_idx, expert_idx):
+            expert_utils.stash_expert_params(experts, ctx.shadows[store_idx], expert_idx)
+        pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
+        def collect_fn(store_idx, root, expert_idx):
+            grad = ctx.shadows[store_idx]
+            expert_utils.collect_expert_grads(experts, grad, expert_idx)
             fmoe_native.reduce_grad(grad, root, ctx.expert_size)
-        set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx])
+        set_grad_fn = lambda store_idx, expert_idx: expert_utils.set_grads(experts, ctx.shadows[store_idx], expert_idx)

         grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
         grad_in_buf = fmoe_native.smart_sch_backward(
@@ -108,7 +111,7 @@ class MoEForward(Function):
                 _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)

         grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)

-        return (None, None, grad_in, None, None, None, None, None, None, None, None)
+        return (None, None, grad_in, None, None, None, None, None, None, None, None, None)


 policy_fn = None
@@ -117,8 +120,6 @@ policy_fn = None
 def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size,
         experts=None, stored_models=None):
     # TODO: Using multiple tensors as input is to be supported.
     assert(isinstance(inp, torch.Tensor))
-    # TODO: Support many experts on each process
-    assert(n_expert == 1)
     (pos, local_expert_count,
@@ -143,4 +144,4 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp
     return MoEForward.apply(expert_fn, experts, inp,
             torch.div(pos, topk, rounding_mode='floor'), pos,
             local_expert_count, global_expert_count, stored_models,
-            fwd_batch_size, out_batch_size, world_size)
+            fwd_batch_size, out_batch_size, n_expert, world_size)
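Because `MoEForward.apply` now takes the extra `n_expert` argument, `backward` has to return one more `None`: a custom autograd `Function` must return exactly one gradient per `forward` input, with `None` for the non-differentiable ones. A minimal illustration of that rule (toy function, unrelated to MoE):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, num_expert):   # three inputs after ctx
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one return value per forward input: grad for x, None for the rest
        return grad_out * ctx.factor, None, None

x = torch.ones(2, requires_grad=True)
Scale.apply(x, 0.5, 2).sum().backward()
print(x.grad)   # tensor([0.5000, 0.5000])
```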
fmoe/layers.py
@@ -159,16 +159,24 @@ class FMoE(nn.Module):
         if self.experts_fused:
             return self.experts(inp, fwd_expert_count)
         if isinstance(fwd_expert_count, torch.Tensor):
-            fwd_expert_count = fwd_expert_count.cpu().numpy()
+            fwd_expert_count_cpu = fwd_expert_count.cpu().numpy()
         outputs = []
         base_idx = 0
         for i in range(self.num_expert):
-            batch_size = fwd_expert_count[i]
+            batch_size = fwd_expert_count_cpu[i]
             inp_slice = inp[base_idx:base_idx + batch_size]
-            outputs.append(self.experts[i](inp_slice))
+            outputs.append(self.experts[i](inp_slice, torch.tensor([fwd_expert_count[i]])))
             base_idx += batch_size
         return torch.cat(outputs, dim=0)

+    def expert_fn_single(self, inp, fwd_expert_count, idx):
+        r"""
+        forward single expert for smart scheduling.
+        """
+        assert not self.experts_fused, "should not use fused experts"
+        output = self.experts[idx](inp, fwd_expert_count)
+        return output
+
     def mark_parallel_comm(self, expert_dp_comm="none"):
         r"""
         Automatically mark the data parallel comms of the parameters within the
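`expert_fn` still runs every local expert over its slice of the fused batch, while the new `expert_fn_single` runs exactly one local expert, which is what the smart-scheduling callbacks need once they carry an `expert_idx`. A rough sketch of the difference on toy modules (not the real `FMoE` class):

```python
import torch
import torch.nn as nn

# Toy stand-in: a list of per-expert modules plus the two dispatch styles.
d_model, num_expert = 8, 2
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_expert)])

def expert_fn(inp, fwd_expert_count):
    # fused-batch dispatch: slice the batch and run every local expert in turn
    outputs, base = [], 0
    for i, n in enumerate(fwd_expert_count.tolist()):
        outputs.append(experts[i](inp[base:base + n]))
        base += n
    return torch.cat(outputs, dim=0)

def expert_fn_single(inp, fwd_expert_count, idx):
    # smart-scheduling dispatch: the caller already isolated the tokens for expert idx
    return experts[idx](inp)

counts = torch.tensor([3, 5])
x = torch.randn(int(counts.sum()), d_model)
assert expert_fn(x, counts).shape == x.shape
assert expert_fn_single(x[:3], counts[:1], 0).shape == (3, d_model)
```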
@@ -231,7 +239,7 @@ class FMoE(nn.Module):
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]

         fwd = _fmoe_general_global_forward(
-            moe_inp, gate_top_k_idx, self.expert_fn,
+            moe_inp, gate_top_k_idx, self.expert_fn_single if fmoe_faster_schedule else self.expert_fn,
             self.num_expert, self.world_size,
             experts=self.experts
         )
fmoe/megatron/layers.py
@@ -130,9 +130,20 @@ class MegatronMLP(FMoETransformerMLP):
         additional numpy rng is used.
         """
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
-        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
+        if type(self.experts) is nn.ModuleList:
+            for expert in self.experts:
+                _megatron_init_method(expert.htoh4, rng, self.sigma)
+        else:
+            _megatron_init_method(self.experts.htoh4, rng, self.sigma)
         std = self.sigma / math.sqrt(2.0 * self.num_layers)
-        _megatron_init_method(self.experts.h4toh, rng, std)
+        if type(self.experts) is nn.ModuleList:
+            for expert in self.experts:
+                _megatron_init_method(expert.h4toh, rng, std)
+        else:
+            _megatron_init_method(self.experts.h4toh, rng, std)

     def forward(self, inp):
         from megatron import mpu
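With more than one expert per worker, `self.experts` can now be an `nn.ModuleList` of per-expert MLPs instead of a single fused module, so the Megatron-style initialization walks the list. A small sketch of the same branching pattern on toy modules (the helper and sizes here are illustrative, not Megatron's API):

```python
import torch.nn as nn

def init_weights(linear, std):
    # stand-in for _megatron_init_method: normal init with a given std
    nn.init.normal_(linear.weight, mean=0.0, std=std)

class ToyExpert(nn.Module):
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.htoh4 = nn.Linear(d_model, d_hidden)
        self.h4toh = nn.Linear(d_hidden, d_model)

experts = nn.ModuleList([ToyExpert(8, 32) for _ in range(2)])  # num_expert > 1 case

# Same shape as the patched code: iterate when given a ModuleList,
# otherwise fall back to the single fused-expert module.
if type(experts) is nn.ModuleList:
    for expert in experts:
        init_weights(expert.htoh4, std=0.02)
        init_weights(expert.h4toh, std=0.01)
else:
    init_weights(experts.htoh4, std=0.02)
    init_weights(experts.h4toh, std=0.01)
```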
fmoe/transformer.py
@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 from .layers import FMoE
 from .linear import FMoELinear
+from .fastermoe.config import switch_from_env


 class _Expert(nn.Module):
@@ -47,10 +48,11 @@ class FMoETransformerMLP(FMoE):
         expert_rank=0,
         **kwargs
     ):
-        super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
-        self.experts = _Expert(
-            num_expert, d_model, d_hidden, activation, rank=expert_rank
-        )
+        def one_expert(d_model):
+            return _Expert(1, d_model, d_hidden, activation, rank=0)
+        expert = one_expert
+        super().__init__(num_expert=num_expert, d_model=d_model, expert=expert, **kwargs)
         self.mark_parallel_comm(expert_dp_comm)

     def forward(self, inp: torch.Tensor):
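Instead of building one fused `_Expert` that holds all `num_expert` experts, the MLP now hands `FMoE` a factory (`one_expert`) that constructs a single expert for a given `d_model`; the Megatron init path above then expects the base class to hold one module per local expert, e.g. in an `nn.ModuleList`. A minimal sketch of that factory pattern, with a toy base class standing in for `FMoE` (not its real constructor):

```python
import torch.nn as nn

class ToyMoE(nn.Module):
    # Toy stand-in for FMoE's `expert=` handling, not the real implementation.
    def __init__(self, num_expert, d_model, expert):
        super().__init__()
        # one module per local expert, each built by the supplied factory
        self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])

def one_expert(d_model):
    # factory: build a single expert for a given model dimension
    return nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                         nn.Linear(4 * d_model, d_model))

moe = ToyMoE(num_expert=2, d_model=8, expert=one_expert)
print(len(moe.experts))  # 2
```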