OpenDAS / FastMoE / Commits

Commit df715c9f (unverified), authored Mar 21, 2023 by Rick Ho, committed by GitHub on Mar 21, 2023

Merge pull request #146 from laekov/fastermoe_multi_experts

make FasterMoE more general

Parents: 698a12ae, 14878015
Showing 7 changed files with 94 additions and 60 deletions:

  cuda/fastermoe/smart_schedule.cpp   +1  -1
  cuda/fastermoe/smart_schedule.h     +14 -14
  fmoe/fastermoe/expert_utils.py      +18 -6
  fmoe/fastermoe/schedule.py          +30 -29
  fmoe/layers.py                      +12 -4
  fmoe/megatron/layers.py             +13 -2
  fmoe/transformer.py                 +6  -4
cuda/fastermoe/smart_schedule.cpp

@@ -104,7 +104,7 @@ std::vector<torch::Tensor> _smart_sch_forward(
         if (stored_models_[i]) {
             torch::Tensor t = input_buf.new_empty({expert_size});
             if (i / num_expert == rank) {
-                get_param_fn(t);
+                get_param_fn(t, i % num_expert);
             }
             params.push_back(t);
         }
cuda/fastermoe/smart_schedule.h

@@ -83,7 +83,7 @@ void computePtrs(long num_expert, long rank, long world_size,
 template<typename scalar_t>
 void computeFn(py::function fn, c10::Device device,
         scalar_t* inp_buf, scalar_t* out_buf,
-        long idx, long offset, long micro_batch_size, long d_model,
+        long expert_idx, long store_idx, long offset, long micro_batch_size, long d_model,
         CudaStreamManager* smgr) {
     if (micro_batch_size == 0) {
         return;
@@ -97,7 +97,7 @@ void computeFn(py::function fn, c10::Device device,
     auto oup = torch::from_blob(out_buf + offset * d_model,
             {micro_batch_size, d_model}, options);
     smgr->use_default = true;
-    fn(inp, oup, idx);
+    fn(inp, oup, expert_idx, store_idx);
     smgr->use_default = false;
 }
@@ -174,7 +174,7 @@ void fmoe_cuda_fused_forward_impl(
             if (i / num_expert == rank) {
                 cudaEventCreate(&evt_get);
                 cudaEventRecord(evt_get, torch_stream);
-                FMOE_SWE(smgr->stream(1), evt_get);
+                FMOE_SWE(smgr->stream(0), evt_get);
                 cudaEventDestroy(evt_get);
             }
             NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
@@ -196,7 +196,7 @@ void fmoe_cuda_fused_forward_impl(
                     (from_base + pipeline_gran)] - offset;
             computeFn(forward_fn, device,
                     global_input_buf, global_output_buf,
-                    step, offset, micro_batch_size, d_model, smgr);
+                    (long)ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], torch_stream);
     }
@@ -204,17 +204,17 @@ void fmoe_cuda_fused_forward_impl(
     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
-            stash_fn(params[si], si);
             FMOE_SWE(torch_stream, evt_shadow[si]);
+            stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
             computeFn(forward_fn, device,
                     input_buf, output_buf,
-                    n_groups + si, offset, micro_batch_size, d_model, smgr);
+                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
             ++si;
         }
     }
-    pop_fn();
+    pop_fn(0);
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
@@ -319,13 +319,13 @@ void fmoe_cuda_fused_backward_impl(
     cudaEvent_t *evt_reduce = new cudaEvent_t[num_expert];
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
-            stash_fn(si);
+            stash_fn(si, 0);
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
             computeFn(backward_fn, device,
                     grad_out, grad_in,
-                    n_groups + si, offset, micro_batch_size, d_model, smgr);
+                    0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
-            collect_fn(si, i / num_expert);
+            collect_fn(si, i / num_expert, 0);
             if (i / num_expert == rank) {
                 cudaEventCreate(evt_reduce + i % num_expert);
                 cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
@@ -333,11 +333,11 @@ void fmoe_cuda_fused_backward_impl(
             ++si;
         }
     }
-    pop_fn();
+    pop_fn(0);
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(1), input_ready[step]);
+        FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -346,7 +346,7 @@ void fmoe_cuda_fused_backward_impl(
             computeFn(backward_fn, device,
                     global_grad_out, global_grad_in,
-                    step, offset, micro_batch_size, d_model, smgr);
+                    (long)ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], torch_stream);
     }
@@ -356,7 +356,7 @@ void fmoe_cuda_fused_backward_impl(
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
-                set_grad_fn(si);
+                set_grad_fn(si, i % num_expert);
             }
             ++si;
         }
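The signature change above threads two separate indices through the fused kernels: expert_idx picks which local expert runs, and store_idx picks the slot where the corresponding activations or parameters are kept. A minimal sketch of the slot layout implied by the calls `step * num_expert + ei` and `n_groups * num_expert + si` (the concrete values below are made up, not from the commit):

# Sketch only; num_expert and n_groups are hypothetical example values.
num_expert, n_groups = 2, 4

def store_slot_regular(step, ei):
    # slot for pipeline group `step`, local expert `ei` (the diff's `step * num_expert + ei`)
    return step * num_expert + ei

def store_slot_shadow(si):
    # shadowed experts are appended after all regular slots and always run as
    # local expert 0 (the diff's `0, n_groups * num_expert + si`)
    return n_groups * num_expert + si

print(store_slot_regular(3, 1))  # 7
print(store_slot_shadow(0))      # 8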
fmoe/fastermoe/expert_utils.py

 import torch

-def get_expert_param_size(e):
+def get_expert_param_size(e, idx):
+    e = e[idx]
     return sum(map(lambda x: x.numel(), e.parameters()))

-def get_expert_params(e, out):
+def get_expert_params(e, out, idx):
+    e = e[idx]
     offset = 0
     for n, p in e.named_parameters():
         seg = out[offset:offset + p.numel()]
@@ -13,20 +15,25 @@ def get_expert_params(e, out):
         seg.copy_(p.data.flatten())

-def stash_expert_params(e, params):
+def stash_expert_params(e, params, idx):
+    e = e[idx]
     if not hasattr(e, 'expert_param_stash'):
         setattr(e, 'expert_param_stash', dict())
+        setattr(e, 'expert_grad_stash', dict())
     offset = 0
     for n, p in e.named_parameters():
         if n not in e.expert_param_stash:
             e.expert_param_stash[n] = p.data.clone()
+            e.expert_grad_stash[n] = p.grad.clone() if p.grad is not None else None
         with torch.no_grad():
             seg = params[offset:offset + p.numel()]
             offset += p.numel()
             p.copy_(seg.reshape(p.shape))
+        p.grad = None

-def pop_expert_params(e):
+def pop_expert_params(e, idx):
+    e = e[idx]
     if not hasattr(e, 'expert_param_stash'):
         return
     if not e.expert_param_stash:
@@ -34,10 +41,14 @@ def pop_expert_params(e):
     for n, p in e.named_parameters():
         with torch.no_grad():
             p.copy_(e.expert_param_stash[n])
+        if e.expert_grad_stash[n] is not None:
+            p.grad = e.expert_grad_stash[n].clone()
     e.expert_param_stash.clear()
+    e.expert_grad_stash.clear()

-def collect_expert_grads(e, grads):
+def collect_expert_grads(e, grads, idx):
+    e = e[idx]
     offset = 0
     for _, p in e.named_parameters():
         seg = grads[offset:offset + p.numel()]
@@ -49,7 +60,8 @@ def collect_expert_grads(e, grads):
             seg.zero_()

-def set_grads(e, grads):
+def set_grads(e, grads, idx):
+    e = e[idx]
     offset = 0
     for n, p in e.named_parameters():
         seg = grads[offset:offset + p.numel()]
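With these changes every helper receives the whole expert container plus an index and does `e = e[idx]` itself, so a worker that owns several experts can shadow or restore them one at a time. A rough usage sketch, assuming FastMoE at this commit is importable and using a toy nn.ModuleList in place of real FMoE experts:

import torch
import torch.nn as nn
from fmoe.fastermoe import expert_utils  # assumes FastMoE at this commit is installed

# Toy stand-in for the per-worker expert list; real experts also accept a count argument.
experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

size = expert_utils.get_expert_param_size(experts, 1)   # numel of expert 1's parameters
flat = torch.empty(size)
expert_utils.get_expert_params(experts, flat, 1)        # flatten expert 1's weights into `flat`
expert_utils.stash_expert_params(experts, flat, 0)      # load them into expert 0, stashing its own
expert_utils.pop_expert_params(experts, 0)              # restore expert 0's original parameters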
fmoe/fastermoe/schedule.py

@@ -23,12 +23,13 @@ class MoEForward(Function):
             local_expert_count, global_expert_count,
             stored_models,
             fwd_batch_size, out_batch_size,
-            world_size):
+            num_expert, world_size):
         local_input_buf = _local_scatter(inp, pos_s)

-        ctx.gibs = [None] * (world_size * 2)
-        ctx.gobs = [None] * (world_size * 2)
-        def _expert_forward(x, y, idx):
+        ctx.gibs = [None] * (world_size * num_expert * 2)
+        ctx.gobs = [None] * (world_size * num_expert * 2)
+        def _expert_forward(x, y, expert_idx, store_idx):
             nothing = lambda a: a
             x = x.data
             with torch.enable_grad():
@@ -40,22 +41,24 @@ class MoEForward(Function):
             except Exception as e:
                 # Ignore the error and fall back for compatibility to older
                 # versions of PyTorch
-                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
+                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
-            ctx.gibs[idx] = x
-            ctx.gobs[idx] = y0
+            ctx.gibs[store_idx] = x
+            ctx.gobs[store_idx] = y0
             y.copy_(y0)

         ctx.experts = experts
         if stored_models.any():
-            ctx.expert_size = expert_utils.get_expert_param_size(experts)
+            ctx.expert_size = expert_utils.get_expert_param_size(experts, 0)
+            for i in range(num_expert):
+                assert ctx.expert_size == expert_utils.get_expert_param_size(experts, i), "report bug"
         else:
             ctx.expert_size = 0
-        get_param_fn = lambda out: expert_utils.get_expert_params(experts, out)
-        pop_fn = lambda: expert_utils.pop_expert_params(experts)
-        ctx.shadows = [None] * world_size
-        def stash_fn(params, idx):
-            expert_utils.stash_expert_params(experts, params)
-            ctx.shadows[idx] = params
+        get_param_fn = lambda out, idx: expert_utils.get_expert_params(experts, out, idx)
+        pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
+        ctx.shadows = [None] * world_size * num_expert
+        def stash_fn(params, store_idx, expert_idx):
+            expert_utils.stash_expert_params(experts, params, expert_idx)
+            ctx.shadows[store_idx] = params

         local_output_buf, gib = fmoe_native.smart_sch_forward(
                 local_input_buf,
@@ -71,7 +74,7 @@ class MoEForward(Function):
         variables = (pos_s, pos_g, local_expert_count, global_expert_count,
                 stored_models, gib, local_input_buf)
-        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
+        ctx.moe_args = fwd_batch_size, inp.shape[0], num_expert, world_size
         ctx.save_for_backward(*variables)
         return out
@@ -80,23 +83,23 @@ class MoEForward(Function):
     def backward(ctx, grad_out):
         (pos_s, pos_g, local_expert_count, global_expert_count,
                 stored_models, _1, _2) = ctx.saved_tensors
-        (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
+        (fwd_batch_size, inp_batch_size, num_expert, world_size) = ctx.moe_args

-        def _expert_backward(grad_y, grad_x, idx):
-            y = ctx.gobs[idx]
-            x = ctx.gibs[idx]
+        def _expert_backward(grad_y, grad_x, expert_idx, store_idx):
+            y = ctx.gobs[store_idx]
+            x = ctx.gibs[store_idx]
             torch.autograd.backward([y], [grad_y])
             grad_x.copy_(x.grad)

         experts = ctx.experts
-        def stash_fn(idx):
-            expert_utils.stash_expert_params(experts, ctx.shadows[idx])
-        pop_fn = lambda: expert_utils.pop_expert_params(experts)
-        def collect_fn(idx, root):
-            grad = ctx.shadows[idx]
-            expert_utils.collect_expert_grads(experts, grad)
+        def stash_fn(store_idx, expert_idx):
+            expert_utils.stash_expert_params(experts, ctx.shadows[store_idx], expert_idx)
+        pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
+        def collect_fn(store_idx, root, expert_idx):
+            grad = ctx.shadows[store_idx]
+            expert_utils.collect_expert_grads(experts, grad, expert_idx)
             fmoe_native.reduce_grad(grad, root, ctx.expert_size)
-        set_grad_fn = lambda idx: expert_utils.set_grads(experts, ctx.shadows[idx])
+        set_grad_fn = lambda store_idx, expert_idx: expert_utils.set_grads(experts, ctx.shadows[store_idx], expert_idx)

         grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
         grad_in_buf = fmoe_native.smart_sch_backward(
@@ -108,7 +111,7 @@ class MoEForward(Function):
                 _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)
         grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)

-        return (None, None, grad_in, None, None, None, None, None, None, None, None)
+        return (None, None, grad_in, None, None, None, None, None, None, None, None, None)

 policy_fn = None
@@ -117,8 +120,6 @@ policy_fn = None
 def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
     # TODO: Using multiple tensors as input is to be supported.
     assert(isinstance(inp, torch.Tensor))
-    # TODO: Support many experts on each process
-    assert(n_expert == 1)
     (
         pos,
         local_expert_count,
@@ -143,4 +144,4 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, exp
     return MoEForward.apply(expert_fn, experts, inp,
             torch.div(pos, topk, rounding_mode='floor'), pos,
             local_expert_count, global_expert_count, stored_models,
-            fwd_batch_size, out_batch_size, world_size)
+            fwd_batch_size, out_batch_size, n_expert, world_size)
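Taken together, the Python callbacks handed to the native smart schedule now separate the expert being executed from the slot used for stashing, mirroring the C++ computeFn change. A hand-written illustration of the new callback shapes (the buffer size and values below are invented, not library code):

# Illustration only; not FastMoE code.
shadows = [None] * 4   # world_size * num_expert slots, hypothetical size

def stash_fn(params, store_idx, expert_idx):
    # remember the broadcast parameters of local expert `expert_idx` in slot `store_idx`
    shadows[store_idx] = params

def pop_fn(expert_idx):
    # restore the original parameters of one local expert
    pass

stash_fn("params-from-rank-1-expert-0", store_idx=1, expert_idx=0)
pop_fn(0)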
fmoe/layers.py

@@ -159,16 +159,24 @@ class FMoE(nn.Module):
         if self.experts_fused:
             return self.experts(inp, fwd_expert_count)
         if isinstance(fwd_expert_count, torch.Tensor):
-            fwd_expert_count = fwd_expert_count.cpu().numpy()
+            fwd_expert_count_cpu = fwd_expert_count.cpu().numpy()
         outputs = []
         base_idx = 0
         for i in range(self.num_expert):
-            batch_size = fwd_expert_count[i]
+            batch_size = fwd_expert_count_cpu[i]
             inp_slice = inp[base_idx:base_idx + batch_size]
-            outputs.append(self.experts[i](inp_slice))
+            outputs.append(self.experts[i](inp_slice, torch.tensor([fwd_expert_count[i]])))
             base_idx += batch_size
         return torch.cat(outputs, dim=0)

+    def expert_fn_single(self, inp, fwd_expert_count, idx):
+        r"""
+        forward single expert for smart scheduling.
+        """
+        assert not self.experts_fused, "should not use fused experts"
+        output = self.experts[idx](inp, fwd_expert_count)
+        return output
+
     def mark_parallel_comm(self, expert_dp_comm="none"):
         r"""
         Automatically mark the data parallel comms of the parameters within the
@@ -231,7 +239,7 @@ class FMoE(nn.Module):
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]

         fwd = _fmoe_general_global_forward(
-            moe_inp, gate_top_k_idx, self.expert_fn,
+            moe_inp, gate_top_k_idx, self.expert_fn_single if fmoe_faster_schedule else self.expert_fn,
             self.num_expert, self.world_size,
             experts=self.experts
         )
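The new expert_fn_single is what the faster schedule passes down instead of expert_fn: it runs exactly one local expert on a contiguous slice, while expert_fn still walks all local experts over a concatenated batch. A toy comparison of the two shapes, using plain nn.Linear modules in place of FMoE experts (real experts also take a count tensor):

import torch
import torch.nn as nn

experts = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

def expert_fn(inp, fwd_expert_count):
    # run every local expert on its slice of the concatenated batch
    outputs, base = [], 0
    for i, count in enumerate(fwd_expert_count.tolist()):
        outputs.append(experts[i](inp[base:base + count]))
        base += count
    return torch.cat(outputs, dim=0)

def expert_fn_single(inp, fwd_expert_count, idx):
    # run only expert `idx`, as the smart schedule's per-expert callback requires
    return experts[idx](inp)

x = torch.randn(5, 4)
counts = torch.tensor([2, 3])
assert torch.allclose(expert_fn(x, counts)[:2], expert_fn_single(x[:2], counts[:1], 0))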
fmoe/megatron/layers.py

@@ -130,9 +130,20 @@ class MegatronMLP(FMoETransformerMLP):
         additional numpy rng is used.
         """
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
-        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
+        if type(self.experts) is nn.ModuleList:
+            for expert in self.experts:
+                _megatron_init_method(expert.htoh4, rng, self.sigma)
+        else:
+            _megatron_init_method(self.experts.htoh4, rng, self.sigma)
         std = self.sigma / math.sqrt(2.0 * self.num_layers)
-        _megatron_init_method(self.experts.h4toh, rng, std)
+        if type(self.experts) is nn.ModuleList:
+            for expert in self.experts:
+                _megatron_init_method(expert.h4toh, rng, std)
+        else:
+            _megatron_init_method(self.experts.h4toh, rng, std)

     def forward(self, inp):
         from megatron import mpu
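Since self.experts can now be either a single fused module or an nn.ModuleList with one module per local expert, the Megatron initializer branches on the container type. A small sketch of that pattern (the helper name below is illustrative, not the module's API):

import torch.nn as nn

def init_all(experts, init_fn):
    # apply the same initializer whether `experts` is fused or a per-expert ModuleList
    if type(experts) is nn.ModuleList:
        for expert in experts:
            init_fn(expert)
    else:
        init_fn(experts)

init_all(nn.ModuleList([nn.Linear(2, 2)]), lambda m: nn.init.zeros_(m.weight))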
fmoe/transformer.py

@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 from .layers import FMoE
 from .linear import FMoELinear
+from .fastermoe.config import switch_from_env


 class _Expert(nn.Module):
@@ -47,10 +48,11 @@ class FMoETransformerMLP(FMoE):
         expert_rank=0,
         **kwargs
     ):
-        super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
-        self.experts = _Expert(
-            num_expert, d_model, d_hidden, activation, rank=expert_rank)
+        def one_expert(d_model):
+            return _Expert(1, d_model, d_hidden, activation, rank=0)
+        expert = one_expert
+        super().__init__(num_expert=num_expert, d_model=d_model, expert=expert, **kwargs)
         self.mark_parallel_comm(expert_dp_comm)

     def forward(self, inp: torch.Tensor):
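Instead of constructing one fused _Expert itself, FMoETransformerMLP now hands FMoE a factory that builds a single expert, which lets the base class materialize num_expert separate modules. A minimal sketch of that factory pattern under this assumption (plain nn.Linear stands in for _Expert, and make_expert_list is a made-up helper, not FMoE's internal code):

import torch.nn as nn

def make_expert_list(one_expert, num_expert, d_model):
    # what a base class could do with an `expert=` factory: build one module per slot
    return nn.ModuleList([one_expert(d_model) for _ in range(num_expert)])

one_expert = lambda d_model: nn.Linear(d_model, d_model)   # stand-in for _Expert(1, ...)
experts = make_expert_list(one_expert, num_expert=2, d_model=8)
print(len(experts))  # 2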