OpenDAS / FastMoE, commit d0f07ff7
Authored Jan 25, 2021 by Rick Ho
Parent: 832385c2

Commit message: basic megatron support frame

Showing 5 changed files with 51 additions (+51) and 20 deletions (-20).
Files changed:
    fmoe/megatron.py      +12  -0
    fmoe/moe.py           +17  -0
    fmoe/moe_function.py  +18  -18
    setup.py              +3   -2
    tests/moe_test.py     +1   -0
fmoe/megatron.py (new file, mode 0 → 100644)
from torch import nn
from .moe import FFFN


def create_moe_mlp(args):
    assert args.num_experts % args.model_parallel_size == 0, \
            'Num experts should be multiple of mp size'
    num_experts = args.num_experts // args.model_parallel_size
    fmoe = FFFN(num_experts,
            in_feat=args.hidden_size,
            hidden_feat=args.hidden_size * 4,
            out_feat=args.hidden_size,
            world_size=args.model_parallel_size)
    return fmoe
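For orientation, here is a minimal, hedged sketch of how create_moe_mlp could be driven outside of Megatron-LM, assuming the fmoe_cuda extension is already built. The SimpleNamespace stands in for Megatron's parsed arguments and only carries the three fields the function actually reads; it is an illustration, not the real Megatron argument parser.

```python
# Hedged usage sketch: a stand-in for Megatron-LM's `args` object carrying
# only the fields create_moe_mlp reads. Values are illustrative assumptions.
from types import SimpleNamespace

from fmoe.megatron import create_moe_mlp

args = SimpleNamespace(
    num_experts=8,          # total experts across the model-parallel group
    model_parallel_size=2,  # experts are split evenly over these ranks
    hidden_size=1024,       # Transformer hidden size; the FFN uses 4 * hidden_size
)
moe_mlp = create_moe_mlp(args)  # an FFFN holding 8 // 2 = 4 experts per rank
```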
fmoe/moe.py
@@ -26,6 +26,23 @@ class FMoE(nn.Module):
         return moe(inp, gate.int(), self.weight, self.world_size)


+class FFFN(nn.Module):
+    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096,
+            out_feat=1024, world_size=None,
+            activation=torch.nn.functional.gelu):
+        super(FFFN, self).__init__()
+        self.htoh4 = FMoE(num_expert, in_feat, hidden_feat,
+                world_size=world_size)
+        self.activation = activation
+        self.h4toh = FMoE(num_expert, hidden_feat, out_feat,
+                world_size=world_size)
+
+    def forward(self, inp, gate):
+        x = self.htoh4(inp)
+        x = self.activation(x)
+        x = self.h4toh(x)
+        return x
+
+
 class BruteForceMoE(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
             world_size=0):
 ...
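The new FFFN block stacks two FMoE layers with a GeLU between them, mirroring the h to 4h to h feed-forward of a Transformer block; judging from the shared gate argument, both layers are meant to route each token through the weights of that token's expert. As a sanity reference, a plain-PyTorch restatement of that computation might look like the sketch below; the weight layouts are assumptions for illustration and do not reflect the actual fmoe_cuda kernels.

```python
import torch
import torch.nn.functional as F

def fffn_reference(x, gate, w1, w2):
    """Per-token restatement of what FFFN computes (assumed weight layouts).

    x:    [batch, in_feat]                      input tokens
    gate: [batch]                               expert index chosen per token
    w1:   [num_expert, hidden_feat, in_feat]    htoh4 weights (assumption)
    w2:   [num_expert, out_feat, hidden_feat]   h4toh weights (assumption)
    """
    h = torch.einsum('bi,bhi->bh', x, w1[gate])     # each token uses its expert's W1
    h = F.gelu(h)                                   # activation between the two MoE layers
    return torch.einsum('bh,boh->bo', h, w2[gate])  # then its expert's W2
```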
fmoe/moe_function.py
 import torch
 from torch.autograd import Function
-import moe_cuda
+import fmoe_cuda


 class MOELocal(Function):
     @staticmethod
     def forward(ctx, inp, gate, weight):
-        expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
-        input_buf, = moe_cuda.local_scatter(inp, pos)
-        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
-        output = moe_cuda.local_gather(output_buf, pos)
+        expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
+        input_buf, = fmoe_cuda.local_scatter(inp, pos)
+        output_buf, = fmoe_cuda.forward(input_buf, weight, expert_count)
+        output = fmoe_cuda.local_gather(output_buf, pos)
         variables = [input_buf, gate, weight, expert_count, pos]
         ctx.save_for_backward(*variables)
 ...
@@ -20,10 +20,10 @@ class MOELocal(Function):
     def backward(ctx, grad_out):
         input_buf, gate, weight, expert_count, pos = ctx.saved_tensors
-        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
-        grad_inp_buf, grad_weight = moe_cuda.backward(
+        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
+        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                 grad_out_buf, input_buf, weight, expert_count)
-        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)
+        grad_inp, = fmoe_cuda.local_gather(grad_inp_buf, pos)
         return grad_inp, None, grad_weight
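Aside from the moe_cuda to fmoe_cuda rename, MOELocal's forward path is: count tokens per expert, permute tokens so each expert's tokens are contiguous (local_scatter), run the batched per-expert GEMM (forward), and undo the permutation (local_gather). A rough, unfused pure-PyTorch equivalent, assuming a weight layout of [num_expert, out_feat, in_feat], would be:

```python
import torch

def moe_local_forward_reference(inp, gate, weight):
    """Unfused reference of the MOELocal forward path (layouts are assumptions)."""
    num_expert = weight.shape[0]
    pos = torch.argsort(gate)                                  # local_scatter permutation
    expert_count = torch.bincount(gate, minlength=num_expert)  # tokens per expert
    input_buf = inp[pos]                                       # expert-contiguous tokens
    output_buf = torch.empty(inp.shape[0], weight.shape[1],
                             dtype=inp.dtype, device=inp.device)
    start = 0
    for e in range(num_expert):                                # one GEMM per expert slice
        cnt = int(expert_count[e])
        if cnt:
            output_buf[start:start + cnt] = input_buf[start:start + cnt] @ weight[e].t()
        start += cnt
    output = torch.empty_like(output_buf)
    output[pos] = output_buf                                   # local_gather restores order
    return output
```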
@@ -33,20 +33,20 @@ class MOEGlobal(Function):
     def forward(ctx, inp, gate, weight, world_size):
         num_expert = weight.shape[0]
-        local_expert_count, pos = moe_cuda.expert_count(gate,
+        local_expert_count, pos = fmoe_cuda.expert_count(gate,
                 world_size * num_expert)
-        global_expert_count, fwd_expert_count = moe_cuda.expert_exchange(
+        global_expert_count, fwd_expert_count = fmoe_cuda.expert_exchange(
                 local_expert_count, num_expert, world_size)
         fwd_batch_size = int(fwd_expert_count.sum().item())
-        local_input_buf, = moe_cuda.local_scatter(inp, pos)
-        local_output_buf, global_input_buf = moe_cuda.global_fused_forward(
+        local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
+        local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
                 local_input_buf, weight,
                 local_expert_count, global_expert_count,
                 fwd_batch_size, inp.shape[0], world_size)
-        output, = moe_cuda.local_gather(local_output_buf, pos)
+        output, = fmoe_cuda.local_gather(local_output_buf, pos)
         variables = (global_input_buf, gate, weight,
                 local_expert_count, global_expert_count, fwd_expert_count,
 ...
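In MOEGlobal, expert_exchange tells every rank how many tokens each peer will send to each of its local experts, which is where fwd_expert_count (and from it fwd_batch_size) comes from. Below is a hedged sketch of that bookkeeping using torch.distributed; the all_to_all_single call is only a stand-in, since the fmoe_cuda extension performs the exchange through its own NCCL/MPI path.

```python
import torch
import torch.distributed as dist

def expert_exchange_reference(local_expert_count, num_expert, world_size):
    """Swap per-(rank, expert) token counts between ranks (illustrative only).

    local_expert_count: [world_size * num_expert] tensor of how many tokens
    this rank routes to each remote (rank, expert) slot.
    """
    global_expert_count = torch.empty_like(local_expert_count)
    # Each rank sends num_expert counts to every peer and receives the counts
    # destined for its own experts (requires an initialized process group).
    dist.all_to_all_single(global_expert_count, local_expert_count)
    # Tokens this rank must actually push through its experts, summed over senders:
    fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
    return global_expert_count, fwd_expert_count
```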
@@ -63,18 +63,18 @@ class MOEGlobal(Function):
                 pos) = ctx.saved_tensors
         num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args
-        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
-        global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
+        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
+        global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                 local_expert_count, global_expert_count,
                 fwd_batch_size, world_size)
-        grad_inp_buf, grad_weight = moe_cuda.backward(
+        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                 global_grad_out_buf, input_buf, weight, fwd_expert_count)
-        local_grad_inp_buf, = moe_cuda.global_gather(grad_inp_buf,
+        local_grad_inp_buf, = fmoe_cuda.global_gather(grad_inp_buf,
                 local_expert_count, global_expert_count,
                 local_batch_size, world_size)
-        grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)
+        grad_inp, = fmoe_cuda.local_gather(local_grad_inp_buf, pos)
         return grad_inp, None, grad_weight, None
 ...
setup.py
@@ -12,8 +12,8 @@ if os.environ.get('USE_NCCL', '0') == '1':
 if __name__ == '__main__':
     setuptools.setup(
-        name='fmoe_cuda',
-        packages=setuptools.find_packages(),
+        name='fmoe',
+        packages=['fmoe'],
         ext_modules=[
             CUDAExtension(
                 name='fmoe_cuda',
 ...
@@ -30,6 +30,7 @@ if __name__ == '__main__':
                 }
                 )
             ],
+        version='0.0.1',
         cmdclass={
             'build_ext': BuildExtension
         })
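The packaging change separates the Python package from the compiled extension: the distribution is now named fmoe and ships the fmoe package explicitly, while the CUDA kernels still build as the fmoe_cuda extension, matching the import rename in fmoe/moe_function.py. A hedged summary of the resulting import layout, assuming the CUDA build succeeds:

```python
# Assumed import layout after installing this setup.py (hedged illustration):
import fmoe        # pure-Python package listed in packages=['fmoe']
import fmoe_cuda   # compiled CUDAExtension named 'fmoe_cuda'

from fmoe.megatron import create_moe_mlp  # the new Megatron entry point
```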
tests/moe_test.py
@@ -159,6 +159,7 @@ def test_dp():
 if __name__ == '__main__':
     torch.distributed.init_process_group(backend='mpi')
+    rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     if len(sys.argv) >= 2:
         task = sys.argv[1]
 ...
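The test now records its MPI rank alongside the world size. A common use of that value, sketched below under the assumption of one GPU per process, is to pin each process to its own device before exercising the distributed MoE paths; how moe_test.py itself uses rank is not shown in this hunk.

```python
import torch

# Hedged sketch of a typical rank-to-GPU mapping (not necessarily moe_test.py's):
torch.distributed.init_process_group(backend='mpi')
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
```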