Unverified Commit 2d1a7dfe authored by HELSON's avatar HELSON Committed by GitHub
Browse files

[zero] add strict ddp mode (#2508)

* [zero] add strict ddp mode

* [polish] add comments for strict ddp mode

* [zero] fix test error
parent c04f1832
...@@ -12,6 +12,7 @@ from colossalai.gemini.memory_tracer import OrderedParamGenerator ...@@ -12,6 +12,7 @@ from colossalai.gemini.memory_tracer import OrderedParamGenerator
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored from colossalai.utils import get_current_device, is_ddp_ignored
...@@ -200,14 +201,18 @@ class ZeroDDP(ColoDDP): ...@@ -200,14 +201,18 @@ class ZeroDDP(ColoDDP):
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space. gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
For more details, see the API reference of ``GeminiManager``. For more details, see the API reference of ``GeminiManager``.
pin_memory (bool): Chunks on CPU Memory use pin-memory. pin_memory (bool): Chunks on CPU Memory use pin-memory.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False. force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
Defaults to False.
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
""" """
def __init__(self, def __init__(self,
module: torch.nn.Module, module: torch.nn.Module,
gemini_manager: GeminiManager, gemini_manager: GeminiManager,
pin_memory: bool = False, pin_memory: bool = False,
force_outputs_fp32: bool = False) -> None: force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False) -> None:
super().__init__(module, process_group=ColoProcessGroup()) super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
...@@ -232,6 +237,9 @@ class ZeroDDP(ColoDDP): ...@@ -232,6 +237,9 @@ class ZeroDDP(ColoDDP):
for p in param_order.generate(): for p in param_order.generate():
assert isinstance(p, ColoParameter) assert isinstance(p, ColoParameter)
if strict_ddp_mode and not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
if is_ddp_ignored(p): if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16) p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
continue continue
......
...@@ -17,6 +17,7 @@ class GeminiDDP(ZeroDDP): ...@@ -17,6 +17,7 @@ class GeminiDDP(ZeroDDP):
placement_policy: str = "cpu", placement_policy: str = "cpu",
pin_memory: bool = False, pin_memory: bool = False,
force_outputs_fp32: bool = False, force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
search_range_mb: int = 32, search_range_mb: int = 32,
hidden_dim: Optional[int] = None, hidden_dim: Optional[int] = None,
min_chunk_size_mb: Optional[float] = None, min_chunk_size_mb: Optional[float] = None,
...@@ -54,4 +55,4 @@ class GeminiDDP(ZeroDDP): ...@@ -54,4 +55,4 @@ class GeminiDDP(ZeroDDP):
search_range_mb=search_range_mb, search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb) min_chunk_size_mb=min_chunk_size_mb)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
...@@ -53,6 +53,14 @@ def gpt2_24b(checkpoint=True): ...@@ -53,6 +53,14 @@ def gpt2_24b(checkpoint=True):
return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
def gpt2_30b(checkpoint=True):
    """Build a ~30B-parameter GPT-2 model (37 layers, hidden size 8192, 16 heads).

    Args:
        checkpoint (bool): enable activation checkpointing. Defaults to True.
    """
    cfg = dict(hidden_size=8192, num_layers=37, num_attention_heads=16, checkpoint=checkpoint)
    return GPTLMModel(**cfg)
def gpt2_40b(checkpoint=True):
    """Build a ~40B-parameter GPT-2 model (50 layers, hidden size 8192, 16 heads).

    Args:
        checkpoint (bool): enable activation checkpointing. Defaults to True.
    """
    cfg = dict(hidden_size=8192, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
    return GPTLMModel(**cfg)
def model_builder(model_size: str) -> callable: def model_builder(model_size: str) -> callable:
if model_size == "gpt2_medium": if model_size == "gpt2_medium":
return gpt2_medium return gpt2_medium
...@@ -66,6 +74,10 @@ def model_builder(model_size: str) -> callable: ...@@ -66,6 +74,10 @@ def model_builder(model_size: str) -> callable:
return gpt2_20b return gpt2_20b
elif model_size == "gpt2_24b": elif model_size == "gpt2_24b":
return gpt2_24b return gpt2_24b
elif model_size == "gpt2_30b":
return gpt2_30b
elif model_size == "gpt2_40b":
return gpt2_40b
else: else:
raise TypeError(f"model_builder {model_size}") raise TypeError(f"model_builder {model_size}")
......
...@@ -187,17 +187,18 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): ...@@ -187,17 +187,18 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
# Gemini + ZeRO DDP # Gemini + ZeRO DDP
def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"): def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto", ddp_flag: bool = True):
fp16_init_scale = 2**5 fp16_init_scale = 2**5
gpu_margin_mem_ratio_for_auto = 0 gpu_margin_mem_ratio_for_auto = 0
if version.parse(CAI_VERSION) > version.parse("0.1.10"): if version.parse(CAI_VERSION) > version.parse("0.1.10"):
model = GeminiDDP(model, model = GeminiDDP(model,
strict_ddp_mode=ddp_flag,
device=get_current_device(), device=get_current_device(),
placement_policy=placement_policy, placement_policy=placement_policy,
pin_memory=True, pin_memory=True,
hidden_dim=model.config.n_embd, hidden_dim=model.config.n_embd,
search_range_mb=64) search_range_mb=128)
# configure the const policy # configure the const policy
if placement_policy == 'const': if placement_policy == 'const':
model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024) model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
...@@ -279,11 +280,12 @@ def main(): ...@@ -279,11 +280,12 @@ def main():
tp_pg = ProcessGroup(tp_degree=args.tp_degree) tp_pg = ProcessGroup(tp_degree=args.tp_degree)
# Tensor Parallelism (TP) # Tensor Parallelism (TP)
# You should notice that v0.1.10 is not compatible with TP degree > 1 # You should notice that v0.1.10 is not compatible with TP degree > 1
tensor_parallelize(model, tp_pg) if args.tp_degree > 1:
tensor_parallelize(model, tp_pg)
# build a Gemini model and a highly optimized cpu optimizer # build a Gemini model and a highly optimized cpu optimizer
# Gemini + ZeRO DP, Note it must be used after TP # Gemini + ZeRO DP, Note it must be used after TP
model, optimizer = build_gemini(model, tp_pg, args.placement) model, optimizer = build_gemini(model, tp_pg, args.placement, args.tp_degree == 1)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
else: else:
......
...@@ -93,7 +93,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None): ...@@ -93,7 +93,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
else: else:
init_device = None init_device = None
model = GeminiDDP(model, init_device, placement_policy, True, False, 32) model = GeminiDDP(model, init_device, placement_policy, True, False)
# The same as the following 3 lines # The same as the following 3 lines
# chunk_manager = ChunkManager(config_dict, init_device=init_device) # chunk_manager = ChunkManager(config_dict, init_device=init_device)
# gemini_manager = GeminiManager(placement_policy, chunk_manager) # gemini_manager = GeminiManager(placement_policy, chunk_manager)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment