OpenDAS / FastMoE

Commit c931c484, authored May 20, 2021 by Rich Ho

fix zero test

Parent: 5680c599
Showing 1 changed file with 7 additions and 2 deletions.

tests/test_zero.py (+7, -2)
tests/test_zero.py (view file @ c931c484)

+import os
 import sys
+import json
 import torch
 from fmoe.layers import _fmoe_general_global_forward
 from fmoe import FMoETransformerMLP
...
@@ -12,7 +14,7 @@ class ConstantGate(torch.nn.Module):
         self.top_k = top_k

     def forward(self, inp):
-        idx = torch.zeros((inp.shape[0] * self.top_k,), dtype=torch.int64,
+        idx = torch.zeros((inp.shape[0], self.top_k), dtype=torch.int64,
                           device=inp.device)
         score = torch.ones((inp.shape[0], 1, self.top_k), device=inp.device) / 2
         return idx, score
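Note on the gate fix above: the removed line built a flattened 1-D index tensor of length batch * top_k, while the fixed line returns one row of top_k expert indices per token, which is presumably the layout the surrounding FMoE layer expects from a gate. A minimal shape sketch (illustration only, not part of the commit):

    import torch

    inp, top_k = torch.rand(4, 8), 2
    # Old (removed): flattened indices, shape (batch * top_k,) == (8,)
    idx_old = torch.zeros((inp.shape[0] * top_k,), dtype=torch.int64)
    # New: per-token top-k expert indices, shape (batch, top_k) == (4, 2)
    idx_new = torch.zeros((inp.shape[0], top_k), dtype=torch.int64)
    assert idx_old.shape == (8,) and idx_new.shape == (4, 2)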
...
@@ -47,7 +49,7 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
         script=__file__
     )

-def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
+def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
     inp = torch.rand(batch_size, d_hidden).cuda()
     model = FMoETransformerMLP(num_expert, d_hidden, d_hidden * 4, world_size,
             gate=ConstantGate).cuda()
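Note on the rename above: prefixing the helper with an underscore presumably keeps pytest's default collection (which only picks up functions named test_*) from running it directly without the distributed setup; the function is still reachable by name through the __main__ launcher below.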
...
@@ -57,6 +59,9 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
 if __name__ == '__main__':
     if len(sys.argv) >= 3:
         args = json.loads(sys.argv[2])
+        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
+        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
+        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
         torch.distributed.init_process_group(backend="nccl")
         args['world_size'] = torch.distributed.get_world_size()
         locals()[sys.argv[1]](**args)
...
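How the launcher block is presumably invoked (an assumption, not shown in the commit): the added lines map Open MPI's OMPI_COMM_WORLD_* variables to the RANK/WORLD_SIZE variables torch.distributed reads, and pin each rank to its own GPU via CUDA_VISIBLE_DEVICES, so the test can be started under mpirun, provided MASTER_ADDR and MASTER_PORT are supplied for the NCCL rendezvous:

    mpirun -np 2 -x MASTER_ADDR=localhost -x MASTER_PORT=29500 \
        python tests/test_zero.py _test_zero_transformer '{"batch_size": 4}'

Here argv[1] selects the function by name and argv[2] is a JSON dict of keyword arguments, matching the json.loads call above.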