Unverified commit e2dfe0d1 authored by Cheng Li, committed by GitHub

Add flops profiler tutorial (#682)

* work on flops profiler tutorial

* update flops profiler tutorial

* add flops profiler tutorial and fix names

* fix trailing ws

* fix names

* remove multistep profiling and update docs

* fix cases where functionals and submodules coexist in a parent module, update readme

* fix typo

* always invoke post hook function

* fix module flops sum and update tests

* update tutorial
parent 6ee3b296
@@ -15,8 +15,7 @@ class DeepSpeedFlopsProfilerConfig(object):
         super(DeepSpeedFlopsProfilerConfig, self).__init__()
         self.enabled = None
-        self.start_step = None
-        self.end_step = None
+        self.profile_step = None
         self.module_depth = None
         self.top_modules = None
@@ -35,13 +34,9 @@ class DeepSpeedFlopsProfilerConfig(object):
                                        FLOPS_PROFILER_ENABLED,
                                        FLOPS_PROFILER_ENABLED_DEFAULT)
-        self.start_step = get_scalar_param(flops_profiler_dict,
-                                           FLOPS_PROFILER_START_STEP,
-                                           FLOPS_PROFILER_START_STEP_DEFAULT)
-        self.end_step = get_scalar_param(flops_profiler_dict,
-                                         FLOPS_PROFILER_END_STEP,
-                                         FLOPS_PROFILER_END_STEP_DEFAULT)
+        self.profile_step = get_scalar_param(flops_profiler_dict,
+                                             FLOPS_PROFILER_PROFILE_STEP,
+                                             FLOPS_PROFILER_PROFILE_STEP_DEFAULT)
         self.module_depth = get_scalar_param(flops_profiler_dict,
                                              FLOPS_PROFILER_MODULE_DEPTH,
@@ -50,3 +45,7 @@ class DeepSpeedFlopsProfilerConfig(object):
         self.top_modules = get_scalar_param(flops_profiler_dict,
                                             FLOPS_PROFILER_TOP_MODULES,
                                             FLOPS_PROFILER_TOP_MODULES_DEFAULT)
+        self.detailed = get_scalar_param(flops_profiler_dict,
+                                         FLOPS_PROFILER_DETAILED,
+                                         FLOPS_PROFILER_DETAILED_DEFAULT)
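
Here `get_scalar_param` is assumed to behave like a keyed lookup with a default; a minimal sketch of that contract (a hypothetical re-implementation, mirroring the calls above):

```python
def get_scalar_param(config_dict, key, default):
    # return the user-provided value when the key is present, else the default
    return config_dict.get(key, default)

flops_profiler_dict = {"enabled": True, "profile_step": 5}
assert get_scalar_param(flops_profiler_dict, "profile_step", 1) == 5
assert get_scalar_param(flops_profiler_dict, "detailed", True) is True  # fallback
```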
@@ -12,11 +12,11 @@ FLOPS_PROFILER_FORMAT = '''
 flops profiler should be enabled as:
 "session_params": {
   "flops_profiler": {
-    "enalbe": [true|false],
-    "start_step": 5,
-    "end_step": 6,
+    "enabled": true,
+    "profile_step": 1,
     "module_depth": -1,
     "top_modules": 3,
+    "detailed": true
   }
 }
 '''
@@ -26,14 +26,14 @@ FLOPS_PROFILER = "flops_profiler"
 FLOPS_PROFILER_ENABLED = "enabled"
 FLOPS_PROFILER_ENABLED_DEFAULT = False
-FLOPS_PROFILER_START_STEP = "start_step"
-FLOPS_PROFILER_START_STEP_DEFAULT = 5
-FLOPS_PROFILER_END_STEP = "end_step"
-FLOPS_PROFILER_END_STEP_DEFAULT = FLOPS_PROFILER_START_STEP_DEFAULT + 1
+FLOPS_PROFILER_PROFILE_STEP = "profile_step"
+FLOPS_PROFILER_PROFILE_STEP_DEFAULT = 1
 FLOPS_PROFILER_MODULE_DEPTH = "module_depth"
 FLOPS_PROFILER_MODULE_DEPTH_DEFAULT = -1
 FLOPS_PROFILER_TOP_MODULES = "top_modules"
 FLOPS_PROFILER_TOP_MODULES_DEFAULT = 3
+FLOPS_PROFILER_DETAILED = "detailed"
+FLOPS_PROFILER_DETAILED_DEFAULT = True
@@ -9,9 +9,9 @@ old_functions = {}
 class FlopsProfiler(object):
-    """Measures the time, number of estimated flops and parameters of each module in a PyTorch model.
+    """Measures the latency, number of estimated floating point operations and parameters of each module in a PyTorch model.
 
-    The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how time, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated time, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input. If multiple forward passes are specified by the user to caputre (in the case where the model have different paths or for more accurate timing), the average profile of the multiple batches is taken.
+    The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how latency, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input.
 
     Args:
         object (torch.nn.Module): The PyTorch model to profile.
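
The profiling flow described in this docstring, as a minimal runnable sketch (the toy model and input shape are made up; the import path is an assumption about where the class lives):

```python
import torch
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler  # assumed path

model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.ReLU())
prof = FlopsProfiler(model)

prof.start_profile()                # attach hooks and patch torch.nn.functionals
_ = model(torch.randn(4, 128))      # one profiled forward pass
print(prof.get_total_flops(as_string=True))   # total MACs, dominated by 4*128*256
print(prof.get_total_params(as_string=True))  # total parameters
prof.end_profile()                  # remove hooks and restore the functionals
```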
@@ -42,20 +42,15 @@ class FlopsProfiler(object):
         # if computing the flops of the functionals in a module
         def pre_hook(module, input):
-            module_flop_count.clear()
-            if len(input) > 0:
-                # Can have multiple inputs, getting the first one
-                input = input[0]
-            module.__steps__ += 1
+            module_flop_count.append([])
 
         module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)
 
         def post_hook(module, input, output):
-            module.__flops__ += sum([elem[1] for elem in module_flop_count])
-            module_flop_count.clear()
+            if module_flop_count:
+                module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]])
+                module_flop_count.pop()
 
-        has_children = len(module._modules.items()) != 0
-        if not has_children:
-            module.__post_hook_handle__ = module.register_forward_hook(post_hook)
+        module.__post_hook_handle__ = module.register_forward_hook(post_hook)
 
         def start_time_hook(module, input):
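
The change above replaces the flat `module_flop_count` list with a stack of per-call frames, so a parent module that calls functionals directly and also contains submodules is charged only for its own functional calls. A simplified, standalone sketch of the bookkeeping (not the profiler's actual hook wiring):

```python
# Stack of frames; one frame of (name, flops) pairs per module call in flight.
module_flop_count = []

def pre_hook():
    module_flop_count.append([])            # open a frame for this module call

def record(name, flops):
    if module_flop_count:
        module_flop_count[-1].append((name, flops))  # charge the innermost module

def post_hook():
    return sum(f for _, f in module_flop_count.pop())  # this module's own flops

pre_hook()             # parent forward begins
record("linear", 100)  # parent's own functional call
pre_hook()             # child forward begins
record("relu", 10)     # child's functional call
assert post_hook() == 10    # child closes with only its own count
assert post_hook() == 100   # parent keeps its own 100
```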
@@ -77,8 +72,6 @@ class FlopsProfiler(object):
         Added attributes and handles are removed recursively on all the modules and the torch.nn.functionals are restored.
         """
         def remove_profile_attrs(module):
-            if hasattr(module, "__steps__"):
-                del module.__steps__
             if hasattr(module, "__flops__"):
                 del module.__flops__
             if hasattr(module, "__params__"):
@@ -117,100 +110,91 @@ class FlopsProfiler(object):
                                      if p.requires_grad)
             module.__start_time__ = 0
             module.__duration__ = 0
-            module.__steps__ = 0
 
         self.model.apply(add_or_reset_attrs)
 
-    def get_total_flops(self, in_str=False):
+    def get_total_flops(self, as_string=False):
         """Returns the total flops of the model.
 
         Args:
-            in_str (bool, optional): whether to output the flops in string. Defaults to False.
+            as_string (bool, optional): whether to output the flops as string. Defaults to False.
         """
-        if self.get_total_steps() == 0:
-            return 0
-        sum = 0
-        for module in self.model.modules():
-            sum += module.__flops__
-        total_flops = sum / self.get_total_steps()
-        return flops_to_string(total_flops) if in_str else total_flops
+        total_flops = get_module_flops(self.model)
+        return macs_to_string(total_flops) if as_string else total_flops
 
-    def get_total_duration(self, in_str=False):
+    def get_total_duration(self, as_string=False):
         """Returns the total duration of the model forward pass.
 
         Args:
-            in_str (bool, optional): whether to output the duration in string. Defaults to False.
+            as_string (bool, optional): whether to output the duration as string. Defaults to False.
         """
-        if self.get_total_steps() == 0:
-            return 0
-        total_duration = self.model.__duration__ / self.get_total_steps()
-        return duration_to_string(total_duration) if in_str else total_duration
+        total_duration = self.model.__duration__
+        return duration_to_string(total_duration) if as_string else total_duration
 
-    def get_total_params(self, in_str=False):
+    def get_total_params(self, as_string=False):
         """Returns the total parameters of the model.
 
         Args:
-            in_str (bool, optional): whether to output the parameters in string. Defaults to False.
+            as_string (bool, optional): whether to output the parameters as string. Defaults to False.
         """
         return params_to_string(
-            self.model.__params__) if in_str else self.model.__params__
+            self.model.__params__) if as_string else self.model.__params__
 
-    def get_total_steps(self):
-        """Returns the total number of steps (or input batches) profiled.
-        """
-        def get_steps(module):
-            if module.__steps__ == 0:
-                sum = 0
-                for m in module.children():
-                    sum += get_steps(m)
-                module.__steps__ = sum
-            return module.__steps__
-
-        total_steps = get_steps(self.model)
-        if total_steps == 0:
-            print("no step is profiled")
-        return total_steps
-
-    def print_model_profile(self):
+    def print_model_profile(self,
+                            profile_step=1,
+                            module_depth=-1,
+                            top_modules=3,
+                            detailed=True):
         """Prints the model graph with the measured profile attached to each module.
         """
         total_flops = self.get_total_flops()
         total_duration = self.get_total_duration()
         total_params = self.get_total_params()
-        total_steps = self.get_total_steps()
 
-        def accumulate_flops(module):
-            has_children = len(module._modules.items()) != 0
-            if not has_children:
-                return module.__flops__
-            else:
-                sum = 0
-                for m in module.children():
-                    sum += m.accumulate_flops()
-                return sum
+        self.flops = total_flops
+        self.params = total_params
+
+        print(
+            "\n-------------------------- DeepSpeed Flops Profiler --------------------------"
+        )
+        print("Summary of forward pass:")
+        print('{:<30} {:<8}'.format('Profile step: ', profile_step))
+        print('{:<30} {:<8}'.format('Number of parameters: ',
+                                    params_to_string(total_params)))
+        print('{:<30} {:<8}'.format('Number of multiply-accumulate operations (MACs): ',
+                                    num_to_string(total_flops)))
+        print('{:<30} {:<8}'.format(
+            'Number of floating point operations ( = 2 * MACs): ',
+            num_to_string(2 * total_flops)))
+        print('{:<30} {:<8}'.format('Latency: ', duration_to_string(total_duration)))
+        print('{:<30} {:<8}'.format('Floating point operations per second (FLOPS): ',
+                                    flops_to_string(2 * total_flops / total_duration)))
 
         def flops_repr(module):
             params = module.__params__
-            flops = 0 if total_steps == 0 else module.accumulate_flops() / total_steps
+            flops = get_module_flops(module)
             items = [
                 params_to_string(params),
                 "{:.2%} Params".format(params / total_params),
-                flops_to_string(flops),
+                macs_to_string(flops),
                 "{:.2%} MACs".format(0.0 if total_flops == 0 else flops / total_flops),
             ]
-            duration = 0 if total_steps == 0 else module.__duration__ / total_steps
+            duration = module.__duration__
+            if duration == 0:  # e.g. ModuleList
+                for m in module.children():
+                    duration += m.__duration__
             items.append(duration_to_string(duration))
-            items.append("{:.2%} time".format(0.0 if total_duration == 0 else duration /
-                                              total_duration))
+            items.append(
+                "{:.2%} latency".format(0.0 if total_duration == 0 else duration /
+                                        total_duration))
             # flops = 2 * MACs
-            items.append(("{:.2} TFLOPS".format(0.0 if duration == 0 else 2 * flops /
-                                                duration / 10**12)))
-            items.append(str(module.__steps__))
+            items.append(flops_to_string(0.0 if duration == 0 else 2 * flops / duration))
             items.append(module.original_extra_repr())
             return ", ".join(items)
 
         def add_extra_repr(module):
-            module.accumulate_flops = accumulate_flops.__get__(module)
             flops_extra_repr = flops_repr.__get__(module)
             if module.extra_repr != flops_extra_repr:
                 module.original_extra_repr = module.extra_repr
@@ -221,13 +205,33 @@ class FlopsProfiler(object):
             if hasattr(module, "original_extra_repr"):
                 module.extra_repr = module.original_extra_repr
                 del module.original_extra_repr
-            if hasattr(module, "accumulate_flops"):
-                del module.accumulate_flops
 
         self.model.apply(add_extra_repr)
+
+        print(
+            "\n----------------------------- Aggregated Profile -----------------------------"
+        )
+        self.print_model_aggregated_profile(module_depth=module_depth,
+                                            top_modules=top_modules)
+
+        if detailed:
+            print(
+                "\n------------------------------ Detailed Profile ------------------------------"
+            )
+            print(
+                "Each module profile is listed after its name in the following order: \nnumber of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency)."
+            )
+            print(
+                "Note: \n1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs (or latency) and the sum of its submodules'.\n2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.\n"
            )
+
         print(self.model)
+
         self.model.apply(del_extra_repr)
+        print(
+            "------------------------------------------------------------------------------"
+        )
 
     def print_model_aggregated_profile(self, module_depth=-1, top_modules=3):
         """Prints the names of the top top_modules modules in terms of aggregated time, flops, and parameters at depth module_depth.
@@ -236,9 +240,6 @@ class FlopsProfiler(object):
             top_modules (int, optional): the number of top modules to show. Defaults to 3.
         """
         info = {}
-        total_steps = self.get_total_steps()
-        if total_steps == 0:
-            return
         if not hasattr(self.model, "__flops__"):
             print(
                 "no __flops__ attribute in the model, call this function after start_profile and before end_profile"
@@ -271,7 +272,7 @@ class FlopsProfiler(object):
             num_items = min(top_modules, len(info[depth]))
 
             sort_flops = {
-                k: flops_to_string(v[0] / total_steps)
+                k: macs_to_string(v[0])
                 for k,
                 v in sorted(info[depth].items(),
                             key=lambda item: item[1][0],
@@ -285,15 +286,15 @@ class FlopsProfiler(object):
                             reverse=True)[:num_items]
             }
 
             sort_time = {
-                k: duration_to_string(v[2] / total_steps)
+                k: duration_to_string(v[2])
                 for k,
                 v in sorted(info[depth].items(),
                             key=lambda item: item[1][2],
                             reverse=True)[:num_items]
             }
-            print(f"Top {num_items} modules in flops at depth {depth} are {sort_flops}")
+            print(f"Top {num_items} modules in MACs at depth {depth} are {sort_flops}")
             print(f"Top {num_items} modules in params at depth {depth} are {sort_params}")
-            print(f"Top {num_items} modules in time at depth {depth} are {sort_time}")
+            print(f"Top {num_items} modules in latency at depth {depth} are {sort_time}")
 
 
 def _prod(dims):
@@ -461,7 +462,8 @@ def wrapFunc(func, funcFlopCompute):
     def newFunc(*args, **kwds):
         flops = funcFlopCompute(*args, **kwds)
-        module_flop_count.append((name, flops))
+        if module_flop_count:
+            module_flop_count[-1].append((name, flops))
         return oldFunc(*args, **kwds)
 
     return newFunc
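
`wrapFunc` is the functional-patching half of the design: each `torch.nn.functional` of interest is swapped for a wrapper that records an estimated flop count into the innermost frame before delegating to the original. The same pattern as a self-contained sketch (the MAC formula for `F.linear` is a simplifying assumption, not the profiler's exact one):

```python
import torch
import torch.nn.functional as F

module_flop_count = [[]]  # one open frame, as a module pre-hook would create

def wrap(func, flop_compute):
    def wrapped(*args, **kwds):
        if module_flop_count:
            module_flop_count[-1].append((func.__name__, flop_compute(*args, **kwds)))
        return func(*args, **kwds)
    return wrapped

def linear_macs(input, weight, bias=None):
    # batch_size * out_features * in_features multiply-accumulates
    return input.shape[0] * weight.numel()

_orig_linear = F.linear
F.linear = wrap(_orig_linear, linear_macs)
_ = F.linear(torch.randn(4, 8), torch.randn(3, 8))
print(module_flop_count[-1])  # [('linear', 96)]
F.linear = _orig_linear       # restore the original, as end_profile() does
```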
@@ -630,25 +632,61 @@ MODULE_HOOK_MAPPING = {
 }
 
+def num_to_string(num, precision=2):
+    if num // 10**9 > 0:
+        return str(round(num / 10.0**9, precision)) + " G"
+    elif num // 10**6 > 0:
+        return str(round(num / 10.0**6, precision)) + " M"
+    elif num // 10**3 > 0:
+        return str(round(num / 10.0**3, precision)) + " K"
+    else:
+        return str(num)
+
+
+def macs_to_string(macs, units=None, precision=2):
+    if units is None:
+        if macs // 10**9 > 0:
+            return str(round(macs / 10.0**9, precision)) + " GMACs"
+        elif macs // 10**6 > 0:
+            return str(round(macs / 10.0**6, precision)) + " MMACs"
+        elif macs // 10**3 > 0:
+            return str(round(macs / 10.0**3, precision)) + " KMACs"
+        else:
+            return str(macs) + " MACs"
+    else:
+        if units == "GMACs":
+            return str(round(macs / 10.0**9, precision)) + " " + units
+        elif units == "MMACs":
+            return str(round(macs / 10.0**6, precision)) + " " + units
+        elif units == "KMACs":
+            return str(round(macs / 10.0**3, precision)) + " " + units
+        else:
+            return str(macs) + " MACs"
+
+
 def flops_to_string(flops, units=None, precision=2):
     if units is None:
+        if flops // 10**12 > 0:
+            return str(round(flops / 10.0**12, precision)) + " TFLOPS"
         if flops // 10**9 > 0:
-            return str(round(flops / 10.0**9, precision)) + " GMACs"
+            return str(round(flops / 10.0**9, precision)) + " GFLOPS"
         elif flops // 10**6 > 0:
-            return str(round(flops / 10.0**6, precision)) + " MMACs"
+            return str(round(flops / 10.0**6, precision)) + " MFLOPS"
        elif flops // 10**3 > 0:
-            return str(round(flops / 10.0**3, precision)) + " KMACs"
+            return str(round(flops / 10.0**3, precision)) + " KFLOPS"
         else:
-            return str(flops) + " MACs"
+            return str(flops) + " FLOPS"
     else:
-        if units == "GMACs":
+        if units == "TFLOPS":
+            return str(round(flops / 10.0**12, precision)) + " " + units
+        if units == "GFLOPS":
             return str(round(flops / 10.0**9, precision)) + " " + units
-        elif units == "MMACs":
+        elif units == "MFLOPS":
             return str(round(flops / 10.0**6, precision)) + " " + units
-        elif units == "KMACs":
+        elif units == "KFLOPS":
             return str(round(flops / 10.0**3, precision)) + " " + units
         else:
-            return str(flops) + " MACs"
+            return str(flops) + " FLOPS"
 
 
 def params_to_string(params_num, units=None, precision=2):
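
A quick sanity check on the helpers above; the expected strings follow directly from the thresholds in the code (assuming the helpers are importable from the profiler module):

```python
print(macs_to_string(439560000))  # 439.56 MMACs
print(macs_to_string(512))        # 512 MACs
print(num_to_string(61710000))    # 61.71 M
print(flops_to_string(87.9e9))    # 87.9 GFLOPS
print(flops_to_string(1.2e12))    # 1.2 TFLOPS
```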
@@ -687,32 +725,40 @@ def duration_to_string(duration, units=None, precision=2):
         return str(round(duration, precision)) + " s"
 
 
+# can not iterate over all submodules using self.model.modules()
+# since modules() returns duplicate modules only once
+def get_module_flops(module):
+    sum = module.__flops__
+    # iterate over immediate children modules
+    for child in module.children():
+        sum += get_module_flops(child)
+    return sum
+
+
 def get_model_profile(
     model,
     input_res,
     input_constructor=None,
     print_profile=True,
-    print_aggregated_profile=True,
+    detailed=True,
     module_depth=-1,
     top_modules=3,
-    warm_up=5,
-    num_steps=10,
-    as_strings=True,
+    warm_up=1,
+    as_string=True,
     ignore_modules=None,
 ):
-    """Returns the total flops, parameters, and profiled steps of a model.
+    """Returns the total MACs and parameters of a model.
 
     Args:
         model ([torch.nn.Module]): the PyTorch model to be profiled.
         input_res (list): input shape or input to the input_constructor
         input_constructor (func, optional): input constructor. If specified, the constructor is applied to input_res and the constructor output is used as the input to the model. Defaults to None.
-        print_profile (bool, optional): whether to print the model graph with the profile annotated. Defaults to True.
-        print_aggregated_profile (bool, optional): whether to print the aggregated profile for top modules. Defaults to True.
+        print_profile (bool, optional): whether to print the model profile. Defaults to True.
+        detailed (bool, optional): whether to print the detailed model profile. Defaults to True.
         module_depth (int, optional): the depth into the nested modules. Defaults to -1 (the inner most modules).
         top_modules (int, optional): the number of top modules to print in the aggregated profile. Defaults to 3.
-        warm_up (int, optional): the number of warm-up steps before measuring the time of each module. Defaults to 5.
-        num_steps (int, optional): the number of steps to profile. Defaults to 10.
-        as_strings (bool, optional): whether to print the output as strings. Defaults to True.
+        warm_up (int, optional): the number of warm-up steps before measuring the latency of each module. Defaults to 1.
+        as_string (bool, optional): whether to print the output as string. Defaults to True.
         ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.
     """
     assert type(input_res) is tuple
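
For models that take keyword inputs rather than a single tensor, `input_constructor` is the hook to use: `get_model_profile` calls `model(**input_constructor(input_res))`. A hedged sketch (the vocabulary size and field names are hypothetical, BERT-style; `model` is any module accepting these kwargs):

```python
import torch

def bert_input_constructor(input_res):
    batch_size, seq_len = input_res
    return {
        'input_ids': torch.randint(0, 30522, (batch_size, seq_len)),
        'attention_mask': torch.ones(batch_size, seq_len, dtype=torch.long),
    }

macs, params = get_model_profile(model,
                                 (4, 128),
                                 input_constructor=bert_input_constructor)
```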
@@ -738,7 +784,6 @@ def get_model_profile(
     prof.start_profile(ignore_list=ignore_modules)
 
-    for _ in range(num_steps):
     if input_constructor:
         input = input_constructor(input_res)
         _ = model(**input)
@@ -756,14 +801,14 @@ def get_model_profile(
     flops = prof.get_total_flops()
     params = prof.get_total_params()
-    steps = prof.get_total_steps()
     if print_profile:
-        prof.print_model_profile()
-    if print_aggregated_profile:
-        prof.print_model_aggregated_profile(module_depth=module_depth,
-                                            top_modules=top_modules)
+        prof.print_model_profile(profile_step=warm_up,
+                                 module_depth=module_depth,
+                                 top_modules=top_modules,
+                                 detailed=detailed)
     prof.end_profile()
-    if as_strings:
-        return flops_to_string(flops), params_to_string(params), steps
+    if as_string:
+        return macs_to_string(flops), params_to_string(params)
 
-    return flops, params, steps
+    return flops, params
@@ -277,11 +277,8 @@ class DeepSpeedEngine(Module):
     def flops_profiler_enabled(self):
         return self._config.flops_profiler_config.enabled
 
-    def flops_profiler_start_step(self):
-        return self._config.flops_profiler_config.start_step
-
-    def flops_profiler_end_step(self):
-        return self._config.flops_profiler_config.end_step
+    def flops_profiler_profile_step(self):
+        return self._config.flops_profiler_config.profile_step
 
     def flops_profiler_module_depth(self):
         return self._config.flops_profiler_config.module_depth
@@ -289,6 +286,9 @@ class DeepSpeedEngine(Module):
     def flops_profiler_top_modules(self):
         return self._config.flops_profiler_config.top_modules
 
+    def flops_profiler_detailed(self):
+        return self._config.flops_profiler_config.detailed
+
     def memory_breakdown(self):
         return self._config.memory_breakdown
@@ -799,30 +799,11 @@ class DeepSpeedEngine(Module):
             **kwargs: variable length keyword arguments
         """
         if self.flops_profiler_enabled(
-        ) and self.global_steps == self.flops_profiler_start_step(
+        ) and self.global_steps == self.flops_profiler_profile_step(
         ) and self.global_rank == 0:
             self.flops_profiler = FlopsProfiler(self.module)
             self.flops_profiler.start_profile(ignore_list=None)
 
-        if self.flops_profiler_enabled(
-        ) and self.global_steps == self.flops_profiler_end_step(
-        ) and self.global_rank == 0:
-            print('{:<30} {:<8}'.format(
-                'Number of multiply-adds: ',
-                self.flops_profiler.get_total_flops(in_str=False)))
-            print('{:<30} {:<8}'.format(
-                'Number of parameters: ',
-                self.flops_profiler.get_total_params(in_str=False)))
-            print('{:<30} {:<8}'.format('Number of steps profiled: ',
-                                         self.flops_profiler.get_total_steps()))
-            self.flops_profiler.print_model_profile()
-            self.flops_profiler.print_model_aggregated_profile(
-                module_depth=self.flops_profiler_module_depth(),
-                top_modules=self.flops_profiler_top_modules())
-            self.flops_profiler.flops = self.flops_profiler.get_total_flops()
-            self.flops_profiler.params = self.flops_profiler.get_total_params()
-            self.flops_profiler.end_profile()
-
         if self.module.training and self.progressive_layer_drop:
             kwargs.update(self.progressive_layer_drop.get_state())
@@ -838,6 +819,16 @@ class DeepSpeedEngine(Module):
         self.timers('forward').stop()
         self.timers('forward_microstep').stop()
 
+        if self.flops_profiler_enabled(
+        ) and self.global_steps == self.flops_profiler_profile_step(
+        ) and self.global_rank == 0:
+            self.flops_profiler.print_model_profile(
+                profile_step=self.global_steps,
+                module_depth=self.flops_profiler_module_depth(),
+                top_modules=self.flops_profiler_top_modules(),
+                detailed=self.flops_profiler_detailed())
+            self.flops_profiler.end_profile()
+
         return loss
 
     def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
@@ -41,6 +41,7 @@ collections:
       - 1Cycle.md
       - lrrt.md
      - zero.md
+      - flops-profiler.md
 defaults:
   - scope:
@@ -45,6 +45,8 @@ lnav:
           url: /docs/config-json/#zero-optimizations-for-fp16-training
         - title: "Logging"
           url: /docs/config-json/#logging
+        - title: "Flops Profiler"
+          url: /docs/config-json/#flops-profiler
         - title: "Activation checkpointing"
           url: /docs/config-json/#activation-checkpointing
         - title: "Sparse Attention"
@@ -84,5 +86,7 @@ lnav:
           url: /tutorials/pipeline/
         - title: "Progressive Layer Dropping"
           url: /tutorials/progressive_layer_dropping/
+        - title: "Flops Profiler"
+          url: /tutorials/flops-profiler/
         - title: "Contributing"
           url: /contributing/
This diff is collapsed.
@@ -240,19 +240,53 @@ comes to data loading. Users simply provide a PyTorch dataset, and DeepSpeed data loaders
 can automatically handle batch creation appropriately.
 
 ## Performance Analysis and Debugging
-For performance debugging, DeepSpeed can give you a detailed breakdown of the time spent
-in different parts of the training by simply enabling it in the `deepspeed_config`
-file.
-Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details.
+
+DeepSpeed provides a set of tools for performance analysis and debugging.
+
+### Wall Clock Breakdown
+
+DeepSpeed provides a detailed breakdown of the time spent
+in different parts of the training.
+This can be enabled by setting the following in the `deepspeed_config` file.
 
 ```json
 {
   "wall_clock_breakdown": true,
+}
+```
+
+### Timing Activation Checkpoint Functions
+
+When activation checkpointing is enabled, profiling the forward and backward time of each checkpoint function can be enabled in the `deepspeed_config` file.
+
+```json
+{
   "activation_checkpointing": {
     "profile": true
   }
 }
+```
+
+### Flops Profiler
+
+The DeepSpeed flops profiler measures the time, flops and parameters of a PyTorch model and shows which modules or layers are the bottleneck. When used with the DeepSpeed runtime, the flops profiler can be configured in the `deepspeed_config` file as follows:
+
+```json
+{
+  "flops_profiler": {
+    "enabled": true,
+    "profile_step": 1,
+    "module_depth": -1,
+    "top_modules": 3,
+    "detailed": true
+  }
+}
 ```
+
+The flops profiler can also be used as a standalone package. Please refer to the [Flops Profiler](/tutorials/flops-profiler) tutorial for more details.
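
For example, a standalone run on a torchvision model might look like the following sketch (the import path and exact options are described in the tutorial):

```python
import torchvision.models as models
from deepspeed.profiling.flops_profiler import get_model_profile  # assumed path

model = models.alexnet()
macs, params = get_model_profile(model,
                                 input_res=(1, 3, 224, 224),
                                 print_profile=True,
                                 detailed=True,
                                 as_string=True)
```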
 
 ## Sparse Attention
 DeepSpeed offers sparse attention to support long sequences. Please refer to the [Sparse Attention](/tutorials/sparse-attention/) tutorial.
This diff is collapsed.
@@ -24,8 +24,7 @@ def test_flops_profiler_in_ds_trainning(tmpdir):
         },
         "flops_profiler": {
             "enabled": True,
-            "start_step": 2,
-            "end_step": 3,
+            "profile_step": 1,
             "module_depth": -1,
             "top_modules": 3,
         },
@@ -100,18 +99,17 @@ def test_flops_profiler_in_inference():
     mod = LeNet5(10)
     batch_size = 1024
     input = torch.randn(batch_size, 1, 32, 32)
-    macs, params, steps = get_model_profile(
+    macs, params = get_model_profile(
         mod,
         tuple(input.shape),
         print_profile=True,
-        print_aggregated_profile=True,
+        detailed=True,
         module_depth=-1,
         top_modules=3,
-        warm_up=5,
-        num_steps=10,
-        as_strings=True,
+        warm_up=1,
+        as_string=True,
         ignore_modules=None,
     )
-    print(macs, params, steps)
-    assert macs == "439.55 MMACs"
+    print(macs, params)
+    assert macs == "439.56 MMACs"
     assert params == "61.71 k"