"tests/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "5be118f405fccc0dcad3820ecff1f9d4d93c9a11"
Unverified Commit b87496a6 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[hotfix] fix auto policy of test_sharded_optim_v2 (#2157)

parent 16335cb5
...@@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector): ...@@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector):
@property @property
def cuda_margin_mem(self) -> float: def cuda_margin_mem(self) -> float:
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda('cuda') return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
...@@ -107,14 +107,6 @@ class MemStats(object): ...@@ -107,14 +107,6 @@ class MemStats(object):
else: else:
raise TypeError raise TypeError
def max_overall_cuda(self, device_type: str) -> float:
if device_type == 'cuda':
return max(self._overall_cuda_list)
elif device_type == 'cpu':
return max(self._overall_cpu_list)
else:
raise TypeError
def clear(self): def clear(self):
self._model_data_cuda_list = [] self._model_data_cuda_list = []
self._overall_cuda_list = [] self._overall_cuda_list = []
......
...@@ -79,9 +79,7 @@ class MemStatsCollector: ...@@ -79,9 +79,7 @@ class MemStatsCollector:
if self._start_flag and not self.use_outside_memstats: if self._start_flag and not self.use_outside_memstats:
# The following code work for ZeroInitContext, which is deprecated in v0.1.12 # The following code work for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] self._memstats.record_max_cuda_model_data(cuda_mem)
self._memstats.append_model_data('cuda', cuda_mem)
self._memstats.append_model_data('cpu', cpu_mem)
def sample_overall_data(self) -> None: def sample_overall_data(self) -> None:
""" """
......
...@@ -64,7 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g ...@@ -64,7 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
zero_model = ShardedModelV2( zero_model = ShardedModelV2(
zero_model, zero_model,
shard_strategy, shard_strategy,
tensor_placement_policy='cpu' if cpu_offload else 'cuda', tensor_placement_policy='cpu' if cpu_offload else 'auto',
reuse_fp16_shard=use_cpuadam, reuse_fp16_shard=use_cpuadam,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment