OpenDAS / FastMoE · Commits

Commit 0a6becae, authored Sep 11, 2023 by Rick Ho
Parent: 945004e7

    fix tests

Showing 6 changed files with 96 additions and 39 deletions (+96 / −39):
tests/moe.py                   +2   −2
tests/test.sh                  +1   −1
tests/test_ddp.py              +35  −7
tests/test_faster_schedule.py  +26  −15
tests/test_faster_shadow.py    +10  −9
tests/test_numerical.py        +22  −5
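In short, the commit reworks the tests on several fronts: the expert modules in tests/moe.py no longer call .cuda() on themselves, leaving device placement to the caller; tests/test.sh tees each rank's output into a per-rank log; tests/test_ddp.py allocates MASTER_PORT from a deterministic counter instead of a random draw and parametrizes the distributed tests over dtypes; tests/test_faster_schedule.py exercises multiple experts per worker; tests/test_faster_shadow.py sizes the shadowed-expert mask as world_size * n_expert and guarantees it is never empty; and tests/test_numerical.py switches dtype parameters from legacy tensor-type strings to torch dtype names, widens the error tolerance for float16 as well as bfloat16, and adds device assertions.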
tests/moe.py  (view file @ 0a6becae)

@@ -78,7 +78,7 @@ class BruteForceMoE(nn.Module):
 class NaiveExpert(nn.Module):
     def __init__(self, d_model):
         super(NaiveExpert, self).__init__()
-        self.linear = nn.Linear(d_model, d_model).cuda()
+        self.linear = nn.Linear(d_model, d_model)

     def forward(self, x, fec=None):
         return self.linear(x)

@@ -89,7 +89,7 @@ class LinearExpert(nn.Module):
         super(LinearExpert, self).__init__()
         self.model = nn.Sequential(
             nn.Linear(d_model, d_model * 2),
             nn.ReLU(),
             nn.Linear(d_model * 2, d_model),
-        ).cuda()
+        )

     def forward(self, x, fec=None):
         return self.model(x)
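The change above removes the eager .cuda() from the expert constructors, so an expert is now built on CPU and the caller decides device and dtype. A minimal sketch of the resulting usage pattern, assuming a CUDA device is available (the d_model value is arbitrary):

import torch
import torch.nn as nn

class NaiveExpert(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # Stays on CPU until the caller moves it.
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x, fec=None):
        return self.linear(x)

# The caller controls placement and precision, mirroring the
# updated tests' .cuda().type(data_type) chain:
expert = NaiveExpert(16).cuda().type(torch.float16)

Building on CPU and casting the whole module once is what makes the new dtype parametrization ('torch.float32', 'torch.bfloat16', 'torch.float16') workable across the test suite.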
tests/test.sh  (view file @ 0a6becae)

@@ -30,4 +30,4 @@ fi
 export CUDA_VISIBLE_DEVICES=$localrank
-exec $@
+exec $@ 2>&1 | tee $RANK.log
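With this change each rank's combined stdout and stderr is also duplicated into a per-rank log file ($RANK.log) via tee, so the output of a failing multi-process test run can be inspected rank by rank after the fact instead of being lost in the interleaved console stream.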
tests/test_ddp.py  (view file @ 0a6becae)

@@ -4,6 +4,7 @@ import os
 import sys
 from typing import Dict
 import random
+import socket as sock

 import pytest
 import torch

@@ -24,6 +25,8 @@ def _ensure_initialized():
         dist.init_process_group(backend="nccl")


+port_count = 0
+
 def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
     device_count = torch.cuda.device_count()
     if device_count < world_size:

@@ -33,7 +36,9 @@ def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
     ps = []
     env["MASTER_ADDR"] = "localhost"
-    env["MASTER_PORT"] = str(random.randint(50000, 60000))
+    global port_count
+    env["MASTER_PORT"] = str(9010 + port_count)
+    port_count += 1
     env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
     env["LD_LIBRARY_PATH"] = os.environ.get("LD_LIBRARY_PATH")

@@ -58,7 +63,7 @@ def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("d_hidden", [32])
 @pytest.mark.parametrize("mp_size", [1, 2])
-@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
+@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
 def test_fmoe_linear_distributed(
     num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
 ):

@@ -83,7 +88,8 @@ def test_fmoe_linear_distributed(
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
 @pytest.mark.parametrize("mp_size", [1, 2])
-def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size):
+@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
+def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size, data_type):
     _run_distributed(
         "_test_fmoe",
         mp_size * 2,

@@ -94,6 +100,7 @@ def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_siz
             "d_model": d_model,
             "expert": expert,
             "mp_size": mp_size,
+            "data_type": data_type,
         },
     )

@@ -137,8 +144,29 @@ if __name__ == "__main__":
         del args["mp_size"]
         locals()[sys.argv[1]](**args)
     else:
-        test_fmoe_local_ddp(mp_size=2)
-        test_fmoe_linear_distributed(
-            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2,
-            data_type="torch.HalfTensor"
-        )
+        torch.distributed.init_process_group(backend="nccl")
+        args = dict(mp_size=1, data_type='torch.float16')
+        args["rank"] = torch.distributed.get_rank()
+        args["world_size"] = torch.distributed.get_world_size()
+        args["mp_group"] = [
+            torch.distributed.new_group(
+                ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
+                backend="nccl",
+            )
+            for j in range(args["world_size"] // args["mp_size"])
+        ][args["rank"] // args["mp_size"]]
+        args["dp_group"] = [
+            torch.distributed.new_group(
+                ranks=[i * args["mp_size"] + j
+                       for i in range(args["world_size"] // args["mp_size"])],
+                backend="nccl",
+            )
+            for j in range(args["mp_size"])
+        ][args["rank"] % args["mp_size"]]
+        args["world_group"] = torch.distributed.new_group(
+            ranks=list(range(args["world_size"])),
+            backend="nccl",
+        )
+        del args["mp_size"]
+        _test_fmoe(4, 2, 16, 2, 'NaiveExpert', **args)
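The MASTER_PORT change replaces a random port draw with a counter starting at 9010, so every _run_distributed call in a pytest session gets a fresh, predictable rendezvous port and two launches can no longer collide by drawing the same random number. A hedged sketch of the pattern in isolation (the helper name _next_master_port is invented here for illustration):

port_count = 0

def _next_master_port(base=9010):
    # Hand out sequential rendezvous ports instead of random ones;
    # repeated launches within one test session never reuse a port.
    global port_count
    port = base + port_count
    port_count += 1
    return str(port)

Compared to random.randint(50000, 60000), the counter also makes failures reproducible: the port a given test used is a pure function of how many launches preceded it.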
tests/test_faster_schedule.py  (view file @ 0a6becae)

@@ -18,7 +18,7 @@ from fmoe.layers import _fmoe_general_global_forward as naive_fwd
 @pytest.mark.parametrize("n_process", [8])
 @pytest.mark.parametrize("d_model", [1024])
 @pytest.mark.parametrize("batch_size", [16])
-@pytest.mark.parametrize("n_expert", [1])
+@pytest.mark.parametrize("n_expert", [1, 4])
 @pytest.mark.parametrize("group_sz", [1, 2, 4])
 def test_faster_schedule(n_process, d_model, batch_size, n_expert, group_sz):
     _run_distributed('_test_faster_schedule',

@@ -45,28 +45,39 @@ def _test_faster_schedule(d_model, batch_size, n_expert):
     x2 = x1.data.clone()
     x2.requires_grad = True
     topk_idx = torch.randint(0, world_size * n_expert, (batch_size, 2)).cuda()
-    m1 = torch.nn.Linear(d_model, d_model).cuda()
-    m2 = torch.nn.Linear(d_model, d_model).cuda()
+    m1s = [torch.nn.Linear(d_model, d_model).cuda() for _ in range(n_expert)]
+    m2s = [torch.nn.Linear(d_model, d_model).cuda() for _ in range(n_expert)]
     with torch.no_grad():
-        m2.weight.copy_(m1.weight)
-        m2.bias.copy_(m1.bias)
+        for m1, m2 in zip(m1s, m2s):
+            m2.weight.copy_(m1.weight)
+            m2.bias.copy_(m1.bias)

-    def ef1(x, fec):
-        y = m1(x)
-        return y
+    def ef1(x, fec, eidx):
+        return m1s[eidx](x)

     def ef2(x, fec):
-        y = m2(x)
+        o = 0
+        ys = []
+        for m, i in zip(m2s, fec):
+            if i > 0:
+                ys.append(m(x[o:o + i]))
+            o += i
+        y = torch.cat(ys)
         return y

     ensure_comm(x1, None)
-    y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size)
+    y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1s)
     y1.sum().backward()
-    y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size)
+    y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2s)
     y2.sum().backward()

-    _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
-            [y1, x1.grad, m1.bias.grad, m1.weight.grad],
-            [y2, x2.grad, m2.bias.grad, m2.weight.grad], rank)
+    _assert_numerical(['out', 'grad_in'],
+            [y1, x1.grad],
+            [y2, x2.grad], rank)
+    for i in range(n_expert):
+        _assert_numerical([f'grad_bias_{i}', f'grad_weight_{i}'],
+                [m1s[i].bias.grad, m1s[i].weight.grad],
+                [m2s[i].bias.grad, m2s[i].weight.grad], rank)

 if __name__ == '__main__':

@@ -75,4 +86,4 @@ if __name__ == '__main__':
         locals()[sys.argv[1]](**args)
     else:
         # test_faster_schedule(8, 16, 16, 1, 2)
-        _test_faster_schedule(4, 2, 1)
+        _test_faster_schedule(4, 2, 4)
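The reference expert function ef2 above relies on the convention that the input batch arrives packed expert-by-expert and fec holds the per-expert token counts, so each expert consumes one contiguous slice. A standalone sketch of that slicing logic (run_experts is a hypothetical name; any list of same-shaped modules works):

import torch

def run_experts(x, fec, experts):
    # x is packed so that the tokens routed to expert e occupy a
    # contiguous slice of length fec[e]; walk the offsets and apply
    # each expert to its own slice, then re-concatenate the outputs.
    out, offset = [], 0
    for module, count in zip(experts, fec):
        if count > 0:
            out.append(module(x[offset:offset + count]))
        offset += count
    return torch.cat(out)

# Example: two experts, the first gets 3 tokens, the second gets 2.
experts = [torch.nn.Linear(4, 4) for _ in range(2)]
y = run_experts(torch.randn(5, 4), [3, 2], experts)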
tests/test_faster_shadow.py  (view file @ 0a6becae)

@@ -20,7 +20,7 @@ from fmoe.layers import _fmoe_general_global_forward as naive_fwd
 @pytest.mark.parametrize("batch_size", [16, 512])
 @pytest.mark.parametrize("n_expert", [1])
 @pytest.mark.parametrize("group_sz", [1, 2, 4])
-@pytest.mark.parametrize("pass_stored", [False, True])
+@pytest.mark.parametrize("pass_stored", [True, False])
 def test_faster_shadow(n_process, d_model, batch_size, n_expert, group_sz, pass_stored):
     _run_distributed('_test_faster_shadow',
             n_process,

@@ -54,7 +54,7 @@ def _test_faster_shadow(d_model, batch_size, n_expert, pass_stored):
         m2.weight.copy_(m1.weight)
         m2.bias.copy_(m1.bias)

-    def ef1(x, fec):
+    def ef1(x, fec, eidx):
         y = m1(x)
         return y

     def ef2(x, fec):

@@ -62,22 +62,23 @@ def _test_faster_shadow(d_model, batch_size, n_expert, pass_stored):
         return y

     if pass_stored:
-        stored_models = torch.randint(0, 2, (world_size,)).bool().cuda()
+        stored_models = torch.randint(0, 2, (world_size * n_expert,)).bool().cuda()
+        while stored_models.sum().item() == 0:
+            stored_models = torch.randint(0, 2, (world_size * n_expert,)).bool().cuda()
+        stored_models[-1] = True
         dist.broadcast(stored_models, 0)
         stored_models = stored_models.cpu()
-        print(stored_models)
+        # if rank == 0:
+        #     print('stored models {}'.format(stored_models))

     ensure_comm(x1, None)
     if pass_stored:
-        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1,
+        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=[m1],
                 stored_models=stored_models)
     else:
-        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=m1)
+        y1 = smart_fwd(x1, topk_idx, ef1, n_expert, world_size, experts=[m1])
     y1.sum().backward()
-    y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=m2)
+    y2 = naive_fwd(x2, topk_idx, ef2, n_expert, world_size, experts=[m2])
     y2.sum().backward()

     _assert_numerical(['out', 'grad_in', 'grad_bias', 'grad_weight'],
             [y1, x1.grad, m1.bias.grad, m1.weight.grad],
 ...
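The stored_models mask now carries one entry per global expert (world_size * n_expert) rather than one per worker, and the test resamples until the mask is non-empty so the shadowing code path is always exercised; pinning the last entry to True makes that guarantee unconditional. A hedged sketch of the sampling in isolation (sample_shadow_mask is an invented name):

import torch

def sample_shadow_mask(world_size, n_expert):
    # One boolean per global expert; resample until at least one expert
    # is marked, then pin the last entry. Strictly speaking the pin alone
    # already rules out an empty mask; the loop mirrors the test as-is.
    n = world_size * n_expert
    mask = torch.randint(0, 2, (n,)).bool()
    while mask.sum().item() == 0:
        mask = torch.randint(0, 2, (n,)).bool()
    mask[-1] = True
    return mask

In the test itself, rank 0's mask is then broadcast so every process agrees on which experts are shadowed.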
tests/test_numerical.py  (view file @ 0a6becae)

@@ -50,8 +50,12 @@ def _perform_forward(
 def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
     for name, mo, ro in zip(names, moe_out_list, raw_out_list):
+        if mo is None and ro is None:
+            continue
+        if mo is None or ro is None:
+            assert False
         err = (mo - ro).abs().max()
-        if err.dtype == torch.bfloat16:
+        if err.dtype == torch.bfloat16 or err.dtype == torch.float16:
             precision *= 100
         print("Rank {} {} abs err {}".format(rank, name, err))
         if err > precision:

@@ -93,7 +97,7 @@ class MyMoE(FMoE):
 @pytest.mark.parametrize("mp_group", [None])
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
-@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
+@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
 def test_fmoe_linear(
     num_expert,
     top_k,

@@ -111,6 +115,9 @@ def test_fmoe_linear(
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
+    if isinstance(data_type, str):
+        data_type = eval(data_type)
+
     moe = MyMoE(
         num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
     ).type(data_type).cuda()

@@ -238,6 +245,9 @@ def test_fmoe(
     if isinstance(expert, str):
         expert = globals()[expert]
+        assert(expert is not None)
+    if isinstance(data_type, str):
+        data_type = eval(data_type)
     moe = FMoE(
         num_expert=num_expert,

@@ -247,7 +257,7 @@ def test_fmoe(
         mp_group=mp_group,
         expert=expert,
         top_k=top_k,
-    ).cuda().to(data_type)
+    ).cuda().type(data_type)
     moe_raw = BruteForceMoE(
         expert=expert,

@@ -255,7 +265,7 @@ def test_fmoe(
         d_model=d_model,
         world_size=world_size,
         top_k=top_k,
-    ).cuda().to(data_type)
+    ).cuda().type(data_type)
     if world_size == 1:
         for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):

@@ -266,9 +276,11 @@ def test_fmoe(
     else:
         assert len(moe.experts) >= 1
         for idx, para in enumerate(moe.experts[0].parameters()):
+            assert(para.device.type == 'cuda')
             para_tensor = torch.cat(
                 [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
             )
+            assert(para_tensor.device.type == 'cuda')
             para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
             torch.distributed.all_gather(para_array, para_tensor)
             para_tensor_gathered = torch.cat(para_array, dim=0)

@@ -419,6 +431,8 @@ def test_fmoe_experts(
     if isinstance(expert, str):
         expert = globals()[expert]
+    if isinstance(data_type, str):
+        data_type = eval(data_type)
     moe = FMoE(
         num_expert=num_expert,

@@ -428,7 +442,7 @@ def test_fmoe_experts(
         mp_group=mp_group,
         expert=expert,
         top_k=top_k,
-    ).cuda().to(data_type)
+    ).cuda().type(data_type)
     moe_raw = BruteForceMoE(
         expert=expert,

@@ -447,9 +461,12 @@ def test_fmoe_experts(
     else:
         assert len(moe.experts) >= 1
         for idx, para in enumerate(moe.experts[0].parameters()):
+            for ep in expert.parameters():
+                assert(ep.device.type == 'cuda')
             para_tensor = torch.cat(
                 [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
             )
+            assert(para_tensor.device.type == 'cuda')
             para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
             torch.distributed.all_gather(para_array, para_tensor)
             para_tensor_gathered = torch.cat(para_array, dim=0)
 ...
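The tolerance widening in _assert_numerical reflects that float16 and bfloat16 carry far fewer mantissa bits than float32, so a 1e-3 absolute-error bound is unrealistic for half-precision runs. The policy in isolation looks roughly like this (tolerance_for is an invented helper name):

import torch

def tolerance_for(dtype, base=1e-3):
    # Half-precision comparisons get a 100x wider absolute-error budget,
    # matching the test's "precision *= 100" for float16/bfloat16.
    if dtype in (torch.float16, torch.bfloat16):
        return base * 100
    return base

print(tolerance_for(torch.float16))  # 0.1
print(tolerance_for(torch.float32))  # 0.001

The new None handling in the same function also tightens the contract: a pair where both sides are absent is skipped, while a pair where only one side is absent fails outright instead of crashing on the subtraction.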