Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
d81cb62f
Unverified
Commit
d81cb62f
authored
Mar 25, 2021
by
Jiezhong Qiu
Committed by
GitHub
Mar 25, 2021
Browse files
Merge pull request #22 from laekov/laekov/fix-tests
fix tests after updating megatron
parents
cac233f3
6868ed2a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
1 deletion
+39
-1
tests/test_numerical.py
tests/test_numerical.py
+1
-1
tests/test_zero.py
tests/test_zero.py
+38
-0
No files found.
tests/test_numerical.py
View file @
d81cb62f
...
@@ -12,7 +12,7 @@ from fmoe.gates import NaiveGate
...
@@ -12,7 +12,7 @@ from fmoe.gates import NaiveGate
from
fmoe.layers
import
FMoE
from
fmoe.layers
import
FMoE
from
fmoe.transformer
import
_Expert
from
fmoe.transformer
import
_Expert
from
fmoe.distributed
import
DistributedGroupedDataParallel
as
LocalDDP
from
fmoe.distributed
import
DistributedGroupedDataParallel
as
LocalDDP
from
fmoe.megatron
import
_megatron_init_method
from
fmoe.megatron
.layers
import
_megatron_init_method
from
moe
import
BruteForceMoELinear
,
BruteForceMoE
,
NaiveExpert
,
LinearExpert
from
moe
import
BruteForceMoELinear
,
BruteForceMoE
,
NaiveExpert
,
LinearExpert
...
...
tests/test_zero.py
0 → 100644
View file @
d81cb62f
import
torch
from
fmoe.layers
import
_fmoe_general_global_forward
from
fmoe
import
FMoETransformerMLP
class
ConstantGate
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
d_model
,
num_expert
,
world_size
,
top_k
=
1
):
super
().
__init__
()
self
.
top_k
=
top_k
def
forward
(
self
,
inp
):
idx
=
torch
.
zeros
((
inp
.
shape
[
0
]
*
self
.
top_k
,),
dtype
=
torch
.
int64
,
device
=
inp
.
device
)
score
=
torch
.
ones
((
inp
.
shape
[
0
],
1
,
self
.
top_k
),
device
=
inp
.
device
)
/
2
return
idx
,
score
,
None
def
test_zero_fwd
(
num_expert
=
2
,
batch_size
=
4
,
d_hidden
=
8
,
world_size
=
1
):
inp
=
torch
.
rand
(
batch_size
,
d_hidden
).
cuda
()
gate
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
int64
).
cuda
()
x
=
_fmoe_general_global_forward
(
inp
,
gate
,
lambda
x
,
y
:
x
,
num_expert
,
world_size
)
def
test_zero_transformer
(
num_expert
=
2
,
batch_size
=
4
,
d_hidden
=
8
,
world_size
=
1
):
inp
=
torch
.
rand
(
batch_size
,
d_hidden
).
cuda
()
model
=
FMoETransformerMLP
(
num_expert
,
d_hidden
,
d_hidden
*
4
,
world_size
,
gate
=
ConstantGate
).
cuda
()
oup
=
model
(
inp
)
if
__name__
==
'__main__'
:
torch
.
distributed
.
init_process_group
(
backend
=
"nccl"
)
torch
.
cuda
.
set_device
(
torch
.
distributed
.
get_rank
())
# test_zero_fwd(world_size=torch.distributed.get_world_size())
test_zero_transformer
(
num_expert
=
16
,
batch_size
=
4096
,
d_hidden
=
1024
,
world_size
=
torch
.
distributed
.
get_world_size
())
print
(
'done'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment