Unverified Commit 453511ac authored by Cheng Wan, committed by GitHub

Save memory for expert model parallel (#9957)

parent d0730487
@@ -1458,43 +1458,49 @@ def initialize_model_parallel(
             _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
 
     moe_ep_size = expert_model_parallel_size
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
 
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_tp_size):
-            st = i * tensor_model_parallel_size + j
-            en = (i + 1) * tensor_model_parallel_size + j
-            ranks = list(range(st, en, moe_tp_size))
-            group_ranks.append(ranks)
-    _MOE_EP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_ep",
-    )
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_tp_size):
+                st = i * tensor_model_parallel_size + j
+                en = (i + 1) * tensor_model_parallel_size + j
+                ranks = list(range(st, en, moe_tp_size))
+                group_ranks.append(ranks)
+        _MOE_EP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_ep",
+        )
 
     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_ep_size):
-            st = i * tensor_model_parallel_size + j * moe_tp_size
-            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
-            ranks = list(range(st, en))
-            group_ranks.append(ranks)
-    _MOE_TP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_tp",
-    )
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_ep_size):
+                st = i * tensor_model_parallel_size + j * moe_tp_size
+                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+                ranks = list(range(st, en))
+                group_ranks.append(ranks)
+        _MOE_TP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_tp",
+        )
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
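For reference, below is a minimal standalone sketch (not SGLang code; the helper name `moe_group_ranks` and the example sizes are made up for illustration) that reproduces the rank grouping built by the loops in the diff. It also shows the case the patch exploits to save memory: when expert_model_parallel_size equals tensor_model_parallel_size, each expert-parallel group contains exactly the ranks of an existing tensor-parallel group, so _TP can be reused instead of allocating a second communicator.

```python
# Hypothetical helper mirroring the rank-grouping loops from the diff above.
def moe_group_ranks(world_size: int, tp_size: int, ep_size: int):
    num_tp_groups = world_size // tp_size
    moe_tp_size = tp_size // ep_size

    # Expert-parallel groups: strided ranks (stride = moe_tp_size)
    # inside each tensor-parallel group.
    ep_groups = []
    for i in range(num_tp_groups):
        for j in range(moe_tp_size):
            st = i * tp_size + j
            en = (i + 1) * tp_size + j
            ep_groups.append(list(range(st, en, moe_tp_size)))

    # MoE tensor-parallel groups: contiguous blocks of moe_tp_size ranks
    # inside each tensor-parallel group.
    tp_groups = []
    for i in range(num_tp_groups):
        for j in range(ep_size):
            st = i * tp_size + j * moe_tp_size
            en = i * tp_size + (j + 1) * moe_tp_size
            tp_groups.append(list(range(st, en)))

    return ep_groups, tp_groups


if __name__ == "__main__":
    # Example sizes are assumptions, chosen only to make the grouping visible.
    ep, tp = moe_group_ranks(world_size=8, tp_size=4, ep_size=2)
    print(ep)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(tp)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

    # When ep_size == tp_size, every EP group equals a full TP group,
    # which is why the patch assigns _MOE_EP = _TP instead of creating
    # a new process group (and likewise _MOE_TP = _TP when moe_tp_size
    # equals tp_size).
    ep, tp = moe_group_ranks(world_size=8, tp_size=4, ep_size=4)
    print(ep)  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- same ranks as the TP groups
    print(tp)  # [[0], [1], [2], [3], [4], [5], [6], [7]]
```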