Unverified commit 1f992058, authored by Jiarui Fang, committed by GitHub
Browse files

[Gemini] remove static tracer (#2083)

parent 28ef3f29
......@@ -26,27 +26,13 @@ class GeminiManager:
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
"""
def __init__(self,
placement_policy: str,
chunk_manager: ChunkManager,
module: Optional[torch.nn.Module] = None,
use_static_memstats: bool = False) -> None:
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager
# self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
self.use_static_memstats = use_static_memstats
if policy_cls.need_mem_stats:
if use_static_memstats:
assert module is not None
self._mem_stats_collector = StaticMemStatsCollector(module, chunk_manager)
else:
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager)
else:
self._mem_stats_collector = None
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager) if policy_cls.need_mem_stats else None
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
......@@ -60,11 +46,7 @@ class GeminiManager:
def pre_iter(self, *args):
if self._mem_stats_collector and self._warmup:
if self.use_static_memstats:
self._mem_stats_collector.init_mem_stats(*args)
self._warmup = False
else:
self._mem_stats_collector.start_collection()
self._mem_stats_collector.start_collection()
def post_iter(self):
"""This function must be called when each iteration finishes
......
......@@ -9,6 +9,16 @@ __all__ = ['RuntimeMemTracer']
class RuntimeMemTracer():
"""RuntimeMemTracer for the module training using ColoParameter.
Trace non-model memory usage during fwd+bwd process.
The statistics are obtained by feeding the module inputs with the same shape as those used in training
and running a single fwd+bwd pass to trace them.
NOTE:
1. This tracer assumes that the target DNN executes the same operations in every iteration.
2. Module buffers are viewed as non-model data.
"""
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
super().__init__()
......
......@@ -50,5 +50,5 @@ class GeminiDDP(ZeroDDP):
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb)
gemini_manager = GeminiManager(placement_policy, chunk_manager, module)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
......@@ -117,7 +117,7 @@ def run_1d_hybrid_tp(model_name):
else:
output_torch = model_torch(data, label)
loss_torch = output_torch
assert torch.allclose(loss, loss_torch, rtol=1e-2)
assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
torch.distributed.barrier()
loss.backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment