OpenDAS / ColossalAI · Commit 06db94fb

[moe] fix tests

Authored Feb 08, 2024 by ver217
Parent: 65e5d6ba

Showing 4 changed files with 31 additions and 39 deletions:

  colossalai/moe/routers.py                     +1   −1
  colossalai/zero/low_level/low_level_optim.py  +6   −3
  tests/test_moe/test_moe_checkpoint.py         +5   −22
  tests/test_moe/test_moe_router.py             +19  −13
colossalai/moe/routers.py

@@ -47,7 +47,7 @@ class MoeRouter(nn.Module, ABC):
     def get_capacity(self, num_tokens, num_experts, ep_group=None):
         if ep_group is not None:
-            num_tokens_tensor = torch.tensor(num_tokens, device=get_current_device())
+            num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
             dist.all_reduce(num_tokens_tensor, group=ep_group)
             num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
         capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
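The only change here swaps a direct `get_current_device()` call for the accelerator abstraction, `get_accelerator().get_current_device()`, which dispatches to whichever backend is active. For context, a minimal sketch of the pattern this hunk sits in: under expert parallelism, each rank all-reduces its local token count so every rank derives the same capacity. The function name `global_capacity` and the final capacity formula are assumptions for illustration; only the all-reduce/averaging step appears in the diff.

```python
# Illustrative sketch, not library code: derive a globally consistent
# per-expert capacity across an expert-parallel (EP) process group.
import torch
import torch.distributed as dist

def global_capacity(num_tokens: int, num_experts: int,
                    capacity_factor: float, ep_group=None) -> int:
    if ep_group is not None:
        # Sum local token counts across the EP group, then average per rank.
        t = torch.tensor(num_tokens, device="cuda")  # device choice is illustrative
        dist.all_reduce(t, group=ep_group)
        num_tokens = t.item() // dist.get_world_size(ep_group)
    # Assumed standard MoE capacity formula: tokens per expert, scaled.
    return max(int(capacity_factor * num_tokens / num_experts), 1)
```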
colossalai/zero/low_level/low_level_optim.py

@@ -911,11 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
                 else:
                     master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
-        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-            master_moe_param.copy_(working_moe_param)
+        if hasattr(self, "master_moe_params"):
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                master_moe_param.copy_(working_moe_param)

     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
-        return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+        if hasattr(self, "moe_master_to_working_map"):
+            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+        return self._param_store.master_to_working_param
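Both fixes apply the same defensive pattern: the MoE-specific attributes (`master_moe_params`, `moe_master_to_working_map`) only exist when the optimizer wraps an MoE model, so unconditional access broke the non-MoE path. A minimal self-contained sketch of the guard, with `_MapHolder` and its fields as hypothetical stand-ins for the optimizer's internals:

```python
# Minimal sketch of the hasattr() guard; names are hypothetical.
class _MapHolder:
    def __init__(self, moe_map=None):
        self.base_map = {"working_w": "master_w"}
        if moe_map is not None:
            # Only MoE runs ever create this attribute.
            self.moe_master_to_working_map = moe_map

    def get_master_to_working_map(self):
        # Merge the MoE map in only when it exists; otherwise fall back.
        if hasattr(self, "moe_master_to_working_map"):
            return {**self.base_map, **self.moe_master_to_working_map}
        return self.base_map

assert _MapHolder().get_master_to_working_map() == {"working_w": "master_w"}
assert "moe_w" in _MapHolder({"moe_w": "m"}).get_master_to_working_map()
```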
tests/test_moe/test_moe_checkpoint.py

@@ -12,7 +12,6 @@ import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn

 sys.path.append(

@@ -95,6 +94,7 @@ def get_model(parallel):
         precision="bf16",
         tp_size=1,
         pp_size=1,
+        ep_size=1,
         zero_stage=2,
         custom_policy=OpenMoeForCausalLMPolicy(),
     )

@@ -103,6 +103,7 @@ def get_model(parallel):
         precision="bf16",
         tp_size=1,
         pp_size=1,
+        ep_size=dist.get_world_size(),
         zero_stage=2,
         custom_policy=OpenMoeForCausalLMPolicy(),
     )

@@ -111,6 +112,7 @@ def get_model(parallel):
         precision="bf16",
         tp_size=1,
         pp_size=1,
+        ep_size=2,
         zero_stage=2,
         extra_dp_size=2,
         custom_policy=OpenMoeForCausalLMPolicy(),

@@ -120,6 +122,7 @@ def get_model(parallel):
         precision="bf16",
         tp_size=1,
         pp_size=2,
+        ep_size=2,
         zero_stage=1,
         microbatch_size=1,
         custom_policy=OpenMoeForCausalLMPolicy(),

@@ -130,27 +133,6 @@ def get_model(parallel):
 def _test_moe_checkpoint(rank, parallel):
-    if parallel == None:
-        MOE_MANAGER.setup(
-            parallel=None,
-        )
-    elif parallel == "ep":
-        MOE_MANAGER.setup(
-            parallel="EP",
-        )
-    elif parallel == "ep_zero":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=2,
-        )
-    elif parallel == "hybrid":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=1,
-            fixed_ep_size=2,
-            fixed_pp_size=2,
-        )
     model1, booster1, optim1 = get_model(parallel)
     model2, booster2, optim2 = get_model(parallel)
     model3, booster3, optim3 = get_model(parallel)

@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
     _test_moe_checkpoint(rank, parallel)


+@pytest.mark.skip(reason="This is tested in ColossalMOE")
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
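The test no longer configures parallelism through the global `MOE_MANAGER`; each `get_model` branch now passes `ep_size` directly to `MoeHybridParallelPlugin`, and the whole test is skipped in favor of ColossalMOE's coverage. A sketch of the resulting plugin-centric construction for the `hybrid` case, using only kwargs that appear in this diff (anything beyond them would be an assumption), to run inside an already-launched distributed process:

```python
# Sketch mirroring the "hybrid" branch of get_model() in this diff.
# Assumes colossalai.launch(...) has already initialized distributed state.
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

plugin = MoeHybridParallelPlugin(
    precision="bf16",
    tp_size=1,
    pp_size=2,
    ep_size=2,           # expert parallelism is now an explicit plugin kwarg
    zero_stage=1,
    microbatch_size=1,
)
booster = Booster(plugin=plugin)
```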
tests/test_moe/test_moe_router.py

@@ -4,15 +4,21 @@ import torch
 from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter


-@pytest.mark.parametrize(["router", "num_groups"], [
-    (Top1Router(), 1),
-    (Top2Router(), 1),
-    # (TopKRouter(num_selected_experts=3), 4),
-])
-@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
-    (4, 5, 8),
-    (3, 4, 4),
-])
+@pytest.mark.parametrize(
+    ["router", "num_groups"],
+    [
+        (Top1Router(), 1),
+        (Top2Router(), 1),
+        # (TopKRouter(num_selected_experts=3), 4),
+    ],
+)
+@pytest.mark.parametrize(
+    ["batch_size", "seq_len", "num_experts"],
+    [
+        (4, 5, 8),
+        (3, 4, 4),
+    ],
+)
 def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
     x = torch.randn((batch_size * seq_len, num_experts)).cuda()
     if num_groups > 1:

@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
     router.train()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)

     router.eval()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
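The test adapts to a changed return convention: `TopKRouter` now yields just the two tensors the assertions need, while the other routers return additional values, which the test drops by slicing the result with `[1:3]`. A toy illustration of that slicing pattern (`fake_router` is hypothetical, not the library's router):

```python
# Toy illustration: [1:3] keeps the second and third return values,
# discarding whatever the callable now returns first.
import torch

def fake_router(x):
    aux = x.new_tensor(0.1)          # stand-in for the extra leading value
    combine, dispatch = x, x > 0
    return aux, combine, dispatch

x = torch.randn(4, 8)
combine_array, dispatch_mask = fake_router(x)[1:3]
assert combine_array.shape == x.shape and dispatch_mask.shape == x.shape
```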