"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "cd96a84add1e850050a746715be0c4763f549c32"
Unverified Commit 7bcd72a2 authored by Olatunji Ruwase's avatar Olatunji Ruwase Committed by GitHub
Browse files

Make config objects json serializable (#862)


Co-authored-by: default avatarJeff Rasley <jerasley@microsoft.com>
parent fa87a73a
......@@ -3,12 +3,15 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
from deepspeed.profiling.constants import *
class DeepSpeedFlopsProfilerConfig(object):
class DeepSpeedFlopsProfilerConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
"""
docstring
"""
super(DeepSpeedFlopsProfilerConfig, self).__init__()
self.enabled = None
......@@ -24,6 +27,9 @@ class DeepSpeedFlopsProfilerConfig(object):
self._initialize(flops_profiler_dict)
def _initialize(self, flops_profiler_dict):
"""
docstring
"""
self.enabled = get_scalar_param(flops_profiler_dict,
FLOPS_PROFILER_ENABLED,
FLOPS_PROFILER_ENABLED_DEFAULT)
......
......@@ -3,7 +3,7 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
#########################################
# DeepSpeed Activation Checkpointing
......@@ -56,7 +56,7 @@ ACT_CHKPT_DEFAULT = {
}
class DeepSpeedActivationCheckpointingConfig(object):
class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
super(DeepSpeedActivationCheckpointingConfig, self).__init__()
......@@ -74,13 +74,6 @@ class DeepSpeedActivationCheckpointingConfig(object):
self._initialize(act_chkpt_config_dict)
"""
For json serialization
"""
def repr(self):
return self.__dict__
def _initialize(self, act_chkpt_config_dict):
self.partition_activations = get_scalar_param(
act_chkpt_config_dict,
......
......@@ -5,10 +5,21 @@ Licensed under the MIT license.
"""
Collection of DeepSpeed configuration utilities
"""
import json
from collections import Counter
class DeepSpeedConfigObject(object):
"""
For json serialization
"""
def repr(self):
return self.__dict__
def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)
def get_scalar_param(param_dict, param_name, param_default_value):
return param_dict.get(param_name, param_default_value)
......
......@@ -3,13 +3,12 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
from deepspeed.utils import logger
from deepspeed.runtime.zero.constants import *
import json
class DeepSpeedZeroConfig(object):
class DeepSpeedZeroConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
super(DeepSpeedZeroConfig, self).__init__()
......@@ -66,16 +65,6 @@ class DeepSpeedZeroConfig(object):
.format(ZERO_FORMAT))
return zero_config_dict
"""
For json serialization
"""
def repr(self):
return self.__dict__
def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)
def _initialize(self, zero_config_dict):
self.stage = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_STAGE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment