"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c25582d5090cfdcf5f618df4733feecc7a7653b6"
Unverified Commit 7bcd72a2 authored by Olatunji Ruwase's avatar Olatunji Ruwase Committed by GitHub
Browse files

Make config objects json serializable (#862)


Co-authored-by: default avatarJeff Rasley <jerasley@microsoft.com>
parent fa87a73a
...@@ -3,12 +3,15 @@ Copyright (c) Microsoft Corporation ...@@ -3,12 +3,15 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license. Licensed under the MIT license.
""" """
from deepspeed.runtime.config_utils import get_scalar_param from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
from deepspeed.profiling.constants import * from deepspeed.profiling.constants import *
class DeepSpeedFlopsProfilerConfig(object): class DeepSpeedFlopsProfilerConfig(DeepSpeedConfigObject):
def __init__(self, param_dict): def __init__(self, param_dict):
"""
docstring
"""
super(DeepSpeedFlopsProfilerConfig, self).__init__() super(DeepSpeedFlopsProfilerConfig, self).__init__()
self.enabled = None self.enabled = None
...@@ -24,6 +27,9 @@ class DeepSpeedFlopsProfilerConfig(object): ...@@ -24,6 +27,9 @@ class DeepSpeedFlopsProfilerConfig(object):
self._initialize(flops_profiler_dict) self._initialize(flops_profiler_dict)
def _initialize(self, flops_profiler_dict): def _initialize(self, flops_profiler_dict):
"""
docstring
"""
self.enabled = get_scalar_param(flops_profiler_dict, self.enabled = get_scalar_param(flops_profiler_dict,
FLOPS_PROFILER_ENABLED, FLOPS_PROFILER_ENABLED,
FLOPS_PROFILER_ENABLED_DEFAULT) FLOPS_PROFILER_ENABLED_DEFAULT)
......
...@@ -3,7 +3,7 @@ Copyright (c) Microsoft Corporation ...@@ -3,7 +3,7 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license. Licensed under the MIT license.
""" """
from deepspeed.runtime.config_utils import get_scalar_param from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
######################################### #########################################
# DeepSpeed Activation Checkpointing # DeepSpeed Activation Checkpointing
...@@ -56,7 +56,7 @@ ACT_CHKPT_DEFAULT = { ...@@ -56,7 +56,7 @@ ACT_CHKPT_DEFAULT = {
} }
class DeepSpeedActivationCheckpointingConfig(object): class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject):
def __init__(self, param_dict): def __init__(self, param_dict):
super(DeepSpeedActivationCheckpointingConfig, self).__init__() super(DeepSpeedActivationCheckpointingConfig, self).__init__()
...@@ -74,13 +74,6 @@ class DeepSpeedActivationCheckpointingConfig(object): ...@@ -74,13 +74,6 @@ class DeepSpeedActivationCheckpointingConfig(object):
self._initialize(act_chkpt_config_dict) self._initialize(act_chkpt_config_dict)
"""
For json serialization
"""
def repr(self):
return self.__dict__
def _initialize(self, act_chkpt_config_dict): def _initialize(self, act_chkpt_config_dict):
self.partition_activations = get_scalar_param( self.partition_activations = get_scalar_param(
act_chkpt_config_dict, act_chkpt_config_dict,
......
...@@ -5,10 +5,21 @@ Licensed under the MIT license. ...@@ -5,10 +5,21 @@ Licensed under the MIT license.
""" """
Collection of DeepSpeed configuration utilities Collection of DeepSpeed configuration utilities
""" """
import json
from collections import Counter from collections import Counter
class DeepSpeedConfigObject(object):
"""
For json serialization
"""
def repr(self):
return self.__dict__
def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)
def get_scalar_param(param_dict, param_name, param_default_value): def get_scalar_param(param_dict, param_name, param_default_value):
return param_dict.get(param_name, param_default_value) return param_dict.get(param_name, param_default_value)
......
...@@ -3,13 +3,12 @@ Copyright (c) Microsoft Corporation ...@@ -3,13 +3,12 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license. Licensed under the MIT license.
""" """
from deepspeed.runtime.config_utils import get_scalar_param from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
from deepspeed.utils import logger from deepspeed.utils import logger
from deepspeed.runtime.zero.constants import * from deepspeed.runtime.zero.constants import *
import json
class DeepSpeedZeroConfig(object): class DeepSpeedZeroConfig(DeepSpeedConfigObject):
def __init__(self, param_dict): def __init__(self, param_dict):
super(DeepSpeedZeroConfig, self).__init__() super(DeepSpeedZeroConfig, self).__init__()
...@@ -66,16 +65,6 @@ class DeepSpeedZeroConfig(object): ...@@ -66,16 +65,6 @@ class DeepSpeedZeroConfig(object):
.format(ZERO_FORMAT)) .format(ZERO_FORMAT))
return zero_config_dict return zero_config_dict
"""
For json serialization
"""
def repr(self):
return self.__dict__
def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)
def _initialize(self, zero_config_dict): def _initialize(self, zero_config_dict):
self.stage = get_scalar_param(zero_config_dict, self.stage = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_STAGE, ZERO_OPTIMIZATION_STAGE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment