Unverified Commit 453511ac authored by Cheng Wan, committed by GitHub

Save memory for expert model parallel (#9957)

parent d0730487
@@ -1458,10 +1458,15 @@ def initialize_model_parallel(
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
     moe_ep_size = expert_model_parallel_size
     moe_tp_size = tensor_model_parallel_size // moe_ep_size

     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
         group_ranks = []
         for i in range(num_tensor_model_parallel_groups):
             for j in range(moe_tp_size):
@@ -1469,17 +1474,20 @@
                 en = (i + 1) * tensor_model_parallel_size + j
                 ranks = list(range(st, en, moe_tp_size))
                 group_ranks.append(ranks)
         _MOE_EP = init_model_parallel_group(
             group_ranks,
             get_world_group().local_rank,
             backend,
+            use_custom_allreduce=False,
             group_name="moe_ep",
         )

     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
         group_ranks = []
         for i in range(num_tensor_model_parallel_groups):
             for j in range(moe_ep_size):
@@ -1487,12 +1495,10 @@
                 en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
                 ranks = list(range(st, en))
                 group_ranks.append(ranks)
         _MOE_TP = init_model_parallel_group(
             group_ranks,
             get_world_group().local_rank,
             backend,
+            use_custom_allreduce=False,
             group_name="moe_tp",
         )
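For context, the patch saves memory in two ways: when the MoE expert-parallel (or MoE tensor-parallel) size equals the tensor-parallel size, the corresponding MoE group now aliases the existing _TP group instead of building a duplicate communicator, and custom all-reduce is disabled for the newly built MoE groups (use_custom_allreduce=False), which avoids allocating its extra communication buffers. The sketch below is a minimal, standalone illustration of the rank layout that the loops in the diff compute; the helper name moe_group_ranks and its parameters are hypothetical, not part of the patch.

def moe_group_ranks(world_size: int, tp_size: int, ep_size: int):
    """Return (ep_groups, tp_groups) as lists of rank lists, mirroring the
    group construction in initialize_model_parallel above."""
    assert tp_size % ep_size == 0, "EP size must divide TP size"
    moe_tp_size = tp_size // ep_size
    num_tp_groups = world_size // tp_size

    # EP groups: strided slices of each TP group, stride = moe_tp_size.
    ep_groups = []
    for i in range(num_tp_groups):
        for j in range(moe_tp_size):
            st = i * tp_size + j
            en = (i + 1) * tp_size + j
            ep_groups.append(list(range(st, en, moe_tp_size)))

    # MoE-TP groups: contiguous chunks of each TP group, length = moe_tp_size.
    tp_groups = []
    for i in range(num_tp_groups):
        for j in range(ep_size):
            st = i * tp_size + j * moe_tp_size
            en = i * tp_size + (j + 1) * moe_tp_size
            tp_groups.append(list(range(st, en)))

    return ep_groups, tp_groups

if __name__ == "__main__":
    # 8 ranks, TP=4, EP=4: every EP group spans a whole TP group, so
    # _MOE_EP can reuse _TP instead of creating a second communicator.
    ep, tp = moe_group_ranks(world_size=8, tp_size=4, ep_size=4)
    print(ep)  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- identical to the TP groups
    print(tp)  # [[0], [1], ..., [7]] -- moe_tp_size == 1, single-rank groups

Running the example shows the equal-size case that the new `_MOE_EP = _TP` branch short-circuits; the unequal cases still build fresh groups, and the TODO in the diff notes they could eventually be derived by splitting the existing TP communicator rather than creating new ones from scratch.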