Unverified commit 1aaeb596, authored by Jiarui Fang, committed by GitHub

[example] gpt, shard init on all processes (#2366)

parent 1f8ab6f1
@@ -117,7 +117,7 @@ class ColoTensor(torch.Tensor):
     def set_process_group(self, pg: ProcessGroup):
         """set_process_group
         change the pg of the ColoTensor. Note that the valid use cases is limited.
-        Only existing pg is DP and dist spec is REPLICaTE is valid.
+        It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica.

         Args:
             pg (ProcessGroup): target pg
@@ -127,10 +127,10 @@ class ColoTensor(torch.Tensor):
         # if the new pg is the same as the old pg, just returns
         if self.process_group == pg:
             return
-        assert self.process_group.tp_world_size() == 1, \
-            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
+            "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
         assert self.dist_spec.placement.value == 'r', \
-            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
+            "Can not set_process_group on a ColoTensor whose dist spec is not Replica"
         self.process_group = pg
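A minimal sketch (not part of the commit) of the case the relaxed guard enables: a parameter shard-initialized on a ProcessGroup spanning all ranks (tp_degree == world_size, hence dp_world_size() == 1) can now be rebound to the real TP group, provided it is replicated first. The helper name rebind_to_tp_pg and the exact import path are assumptions:

from colossalai.tensor import ColoParameter, ProcessGroup, ReplicaSpec

def rebind_to_tp_pg(param: ColoParameter, tp_degree: int) -> None:
    # The source pg was ProcessGroup(tp_degree=world_size), whose
    # dp_world_size() is 1 -- accepted by the new assert even though
    # tp_world_size() > 1 (the old assert rejected exactly this case).
    param.set_dist_spec(ReplicaSpec())    # dist spec must be 'r' (Replica)
    param.set_process_group(ProcessGroup(tp_degree=tp_degree))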
@@ -148,10 +148,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """
     for mn, module in model.named_modules():
         for pn, param in module.named_parameters(recurse=False):
-            # NOTE() a param maybe shared by tow modules
+            # NOTE() a param maybe shared by two modules
             if hasattr(param, 'visited'):
                 continue
+
+            # if shard init, then convert param to replica and use the dp-only ProcessGroup
+            param: ColoParameter = param
             param.set_dist_spec(ReplicaSpec())
+            param.set_process_group(pg)
+
+            # shard it w.r.t tp pattern
             if 'mlp.c_fc' in mn:
                 if 'weight' in pn or 'bias' in pn:
                     split_param_col_tp1d(param, pg)    # colmn slice
@@ -170,7 +176,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
                 split_param_col_tp1d(param, pg)    # colmn slice
             else:
-                param.set_dist_spec(ReplicaSpec())
             param.visited = True
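The split_param_col_tp1d / split_param_row_tp1d calls above come from the same example file; roughly, they wrap ColoTensor.set_tensor_spec with a 1D shard over the tp ranks of the given ProcessGroup (a paraphrase of the example's helpers, not new API):

from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec

def split_param_single_dim_tp1d(dim: int, param, pg: ProcessGroup):
    # shard `param` along `dim` across pg's tp ranks and mark it TP1D
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)

def split_param_row_tp1d(param, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)     # row slice: shard dim 0

def split_param_col_tp1d(param, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)    # column slice: shard last dim

This is also why the new param.set_process_group(pg) call matters: after shard init the param still lives on the world-sized pg, and rebinding it to the TP pg first makes the ShardSpec above split over tp_degree ranks rather than over all of them.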
@@ -248,27 +253,28 @@ def main():
     torch.manual_seed(123)
     if args.distplan == "colossalai":
         # all param must use the same process group.
-        default_pg = ProcessGroup(tp_degree=args.tp_degree)
-        default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
+        world_size = torch.distributed.get_world_size()
+        shard_pg = ProcessGroup(tp_degree=world_size)
+        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

         # build GPT model
         if version.parse(CAI_VERSION) > version.parse("0.1.10"):
             with ColoInitContext(device=get_current_device(),
                                  dtype=torch.half,
                                  default_dist_spec=default_dist_spec,
-                                 default_pg=default_pg):
+                                 default_pg=shard_pg):
                 model = model_builder(args.model_type)(checkpoint=True)
         else:
             with ColoInitContext(device=get_current_device()):
                 model = model_builder(args.model_type)(checkpoint=True)

-        pg = default_pg
+        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
-        tensor_parallelize(model, pg)
+        tensor_parallelize(model, tp_pg)

         # build a Gemini model and a highly optimized cpu optimizer
         # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, pg, args.placement)
+        model, optimizer = build_gemini(model, tp_pg, args.placement)

         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:
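The motivation for sharding init over all processes rather than over tp_degree: with --shardinit, each rank only materializes 1/world_size of every parameter while the model is being built, and tensor_parallelize later converts each param to a replica and re-shards it with the real TP pattern. A back-of-envelope sketch with hypothetical numbers:

# Hypothetical figures: a 10B-parameter GPT in fp16 on 8 ranks with tp_degree=2.
n_params = 10e9
weight_bytes = 2 * n_params            # fp16 -> ~20 GB of raw weights

replica_init = weight_bytes            # no shard init:              ~20 GB per rank
shard_over_tp = weight_bytes / 2       # old: shard over tp_degree:  ~10 GB per rank
shard_over_world = weight_bytes / 8    # new: shard over world_size: ~2.5 GB per rank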