Unverified Commit e2fbe4d2 authored by Cheng Li's avatar Cheng Li Committed by GitHub
Browse files

squash latest flops profiling changes (#1) (#664)


Co-authored-by: default avatarCheng Li <pistasable@gmail.com>
Co-authored-by: default avatarJeff Rasley <jerasley@microsoft.com>
parent adcfd269
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.profiling.constants import *
class DeepSpeedFlopsProfilerConfig(object):
    """Configuration holder for the DeepSpeed flops profiler.

    Reads the optional FLOPS_PROFILER ("flops_profiler") section of the
    DeepSpeed config dictionary and exposes its settings as attributes,
    falling back to the defaults declared in deepspeed.profiling.constants.
    """
    def __init__(self, param_dict):
        """Builds the profiler configuration from the full DeepSpeed config.

        Args:
            param_dict (dict): the DeepSpeed configuration dictionary; only
                the FLOPS_PROFILER section is consulted (missing section means
                all defaults).
        """
        super(DeepSpeedFlopsProfilerConfig, self).__init__()

        # All attributes are populated by _initialize() below.
        self.enabled = None
        self.start_step = None
        self.end_step = None
        self.module_depth = None
        self.top_modules = None

        if FLOPS_PROFILER in param_dict.keys():
            flops_profiler_dict = param_dict[FLOPS_PROFILER]
        else:
            flops_profiler_dict = {}

        self._initialize(flops_profiler_dict)

    def _initialize(self, flops_profiler_dict):
        """Reads each profiler setting from the config section.

        Args:
            flops_profiler_dict (dict): the "flops_profiler" config section;
                absent keys take the *_DEFAULT values from the constants module.
        """
        self.enabled = get_scalar_param(flops_profiler_dict,
                                        FLOPS_PROFILER_ENABLED,
                                        FLOPS_PROFILER_ENABLED_DEFAULT)
        self.start_step = get_scalar_param(flops_profiler_dict,
                                           FLOPS_PROFILER_START_STEP,
                                           FLOPS_PROFILER_START_STEP_DEFAULT)
        self.end_step = get_scalar_param(flops_profiler_dict,
                                         FLOPS_PROFILER_END_STEP,
                                         FLOPS_PROFILER_END_STEP_DEFAULT)
        self.module_depth = get_scalar_param(flops_profiler_dict,
                                             FLOPS_PROFILER_MODULE_DEPTH,
                                             FLOPS_PROFILER_MODULE_DEPTH_DEFAULT)
        self.top_modules = get_scalar_param(flops_profiler_dict,
                                            FLOPS_PROFILER_TOP_MODULES,
                                            FLOPS_PROFILER_TOP_MODULES_DEFAULT)
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
#########################################
# flops profiler
#########################################
# Flops profiler. By default, this feature is not enabled.
# Users can configure in ds_config.json as below example:
# Usage help shown to users; the key below was misspelled "enalbe" — the
# actual config key (see FLOPS_PROFILER_ENABLED) is "enabled".
FLOPS_PROFILER_FORMAT = '''
flops profiler should be enabled as:
"session_params": {
"flops_profiler": {
"enabled": [true|false],
"start_step": 5,
"end_step": 6,
"module_depth": -1,
"top_modules": 3,
}
}
'''

# Config section name and per-key names/defaults for the flops profiler.
FLOPS_PROFILER = "flops_profiler"

FLOPS_PROFILER_ENABLED = "enabled"
FLOPS_PROFILER_ENABLED_DEFAULT = False

FLOPS_PROFILER_START_STEP = "start_step"
FLOPS_PROFILER_START_STEP_DEFAULT = 5

FLOPS_PROFILER_END_STEP = "end_step"
# Profiling must end after it starts; default is one step after the start.
FLOPS_PROFILER_END_STEP_DEFAULT = FLOPS_PROFILER_START_STEP_DEFAULT + 1

FLOPS_PROFILER_MODULE_DEPTH = "module_depth"
FLOPS_PROFILER_MODULE_DEPTH_DEFAULT = -1

FLOPS_PROFILER_TOP_MODULES = "top_modules"
FLOPS_PROFILER_TOP_MODULES_DEFAULT = 3
# flops-profiler
> Measures the time, number of estimated flops and parameters of each module in a PyTorch Model.
The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how time, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated time, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input. If multiple forward passes are specified by the user to capture (in the case where the model has different paths or for more accurate timing), the average profile of the multiple batches is taken.
The flops estimation is partly inspired by [ptflops](https://github.com/sovrasov/flops-counter.pytorch) with the major difference being that flops-profiler captures `torch.nn.functional` invoked in a module to estimate the flops, thus allowing customized modules in the model (e.g. `ParallelTransformerLayerworks, ParallelSelfAttention, RowParallelLinear, etc.` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). The flops-profiler also supports flops computation at module level (for RNNs).
For models running on multi-node or multi-gpu, only the model parallelism affects the number of flops and parameters (e.g. `--model-parallel-size` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)), i.e., `model_parallel_size * flops = total_flops` and `model_parallel_size * parameters = total_parameters`. The number of gpus or nodes does not affect the output profile.
Below is an example output for LeNet5 with batch size 1024 on a V100 GPU:
```
LeNet5(
61.71 k, 100.00% Params, 439.55 MMACs, 100.00% MACs, 25.62 ms, 100.00% time, 0.034 TFLOPS,
(feature_extractor): Sequential(
50.69 k, 82.15% Params, 428.37 MMACs, 97.46% MACs, 18.41 ms, 71.85% time, 0.047 TFLOPS,
(0): Conv2d(156, 0.25% Params, 125.24 MMACs, 28.49% MACs, 10.56 ms, 41.21% time, 0.024 TFLOPS, 1, 6, kernel_size=(5, 5), stride=(1, 1))
(1): Tanh(0, 0.00% Params, 0.0 MACs, 0.00% MACs, 2.25 ms, 8.79% time, 0.0 TFLOPS, )
(2): AvgPool2d(0, 0.00% Params, 4.82 MMACs, 1.10% MACs, 2.47 ms, 9.63% time, 0.0039 TFLOPS, kernel_size=2, stride=2, padding=0)
(3): Conv2d(2.42 k, 3.92% Params, 247.4 MMACs, 56.28% MACs, 1.08 ms, 4.23% time, 0.46 TFLOPS, 6, 16, kernel_size=(5, 5), stride=(1, 1))
(4): Tanh(0, 0.00% Params, 0.0 MACs, 0.00% MACs, 497.39 us, 1.94% time, 0.0 TFLOPS, )
(5): AvgPool2d(0, 0.00% Params, 1.64 MMACs, 0.37% MACs, 758.24 us, 2.96% time, 0.0043 TFLOPS, kernel_size=2, stride=2, padding=0)
(6): Conv2d(48.12 k, 77.98% Params, 49.27 MMACs, 11.21% MACs, 606.35 us, 2.37% time, 0.16 TFLOPS, 16, 120, kernel_size=(5, 5), stride=(1, 1))
(7): Tanh(0, 0.00% Params, 0.0 MACs, 0.00% MACs, 68.86 us, 0.27% time, 0.0 TFLOPS, )
)
(classifier): Sequential(
11.01 k, 17.85% Params, 11.18 MMACs, 2.54% MACs, 7.03 ms, 27.43% time, 0.0032 TFLOPS,
(0): Linear(10.16 k, 16.47% Params, 10.32 MMACs, 2.35% MACs, 2.71 ms, 10.57% time, 0.0076 TFLOPS, in_features=120, out_features=84, bias=True)
(1): Tanh(0, 0.00% Params, 0.0 MACs, 0.00% MACs, 78.77 us, 0.31% time, 0.0 TFLOPS, )
(2): Linear(850, 1.38% Params, 860.16 KMACs, 0.20% MACs, 4.17 ms, 16.27% time, 0.00041 TFLOPS, in_features=84, out_features=10, bias=True)
)
)
Top 3 modules in flops at depth 2 are {'Conv2d': '421.91 MMACs', 'Linear': '11.18 MMACs', 'AvgPool2d': '6.46 MMACs'}
Top 3 modules in params at depth 2 are {'Conv2d': '50.69 k', 'Linear': '11.01 k', 'Tanh': '0'}
Top 3 modules in time at depth 2 are {'Conv2d': '12.25 ms', 'Linear': '6.88 ms', 'AvgPool2d': '3.23 ms'}
Batch size: 1024
Number of multiply-adds: 439.55 MMACs
Number of parameters: 61.71 k
Number of steps profiled: 10
```
## Installation
The profiler is an integral part of DeepSpeed and can be installed by
```
pip install deepspeed
```
Refer to the [installation of DeepSpeed](https://www.deepspeed.ai/getting-started/#installation) for more information.
## Usage
### With the DeepSpeed runtime
If using DeepSpeed for model training, no explicit API calls are needed to use the flops-profiler.
In DeepSpeed config file, specify:
```python
ds_config = {
...# other deepspeed configs
"flops_profiler": {
"enabled": True,
"start_step": 2,
"end_step": 3,
"module_depth": -1,
"top_modules": 3,
},
}
```
- `"enabled": true` to enable the flops-profiler.
- `"start_step": 2` to start the profiler at step 2. Note that warm-up is necessary for getting accurate timing information.
- `"end_step": 3` to end the profiler at step 3. Note that `end_step > start_step`.
- `"module_depth": -1` to print aggregated module information at the maximum depth (innermost modules). Can be set to any positive number, capped by the maximum depth of the model.
- `"top_modules": 3` to set the number of top modules to print in the aggregated profile.
An example is given in [test_flops_profiler](tests/unit/test_flops_profiler.py).
### Without the DeepSpeed runtime
The flops-profiler can be used as a standalone package outside of the deepspeed runtime.
#### Use the low-level APIs to profile the forward pass in the existing model training workflow
- `start_profile` - starts profiling
- `get_total_flops` - returns the total number of flops
- `get_total_params` - returns the total number of params
- `get_total_duration` - returns the total duration of the model forward pass
- `get_total_steps` - returns the total number of steps (or input batches) profiled.
- `print_model_profile` - prints the profile annotated
- `print_model_aggregated_profile` - prints the aggregated profile for the top modules
- `end_profile` - ends profiling and cleans up, invoked at the end of the profiling and before any printing method.
`flops_to_string`, `params_to_string`, `duration_to_string` are utility functions to convert the metric number to string.
Below is an example of this usage in a typical training workflow.
```python
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
model = Model()
profiler = FlopsProfiler(model)
start_step = 5
end_step = 10
assert (end_step > start_step), "should end profiling after start profiling"
print_profile = True
print_aggregated_profile = True
for step, batch in enumerate(data_loader):
# start profiling at training step "profile_step"
if step == start_step:
profiler.start_profile()
# end profiling and print output at training step "profile_step"
if step == end_step: # if using multi nodes, check global_rank == 0 as well
flops = profiler.get_total_flops()
params = profiler.get_total_params()
duration = profiler.get_total_duration()
steps = profiler.get_total_steps()
if print_profile:
profiler.print_model_profile()
if print_aggregated_profile:
profiler.print_model_aggregated_profile(module_depth=-1, top_modules=3)
profiler.end_profile()
print(flops, params, duration, steps)
# forward() method
loss = model(batch)
# runs backpropagation
loss.backward()
# weight update
optimizer.step()
```
#### Use the high level-API and run the model inference for profiling purpose
Examples of this usage are given below.
##### Classification model example:
```python
import argparse
import sys
import torch
import torchvision.models as models
from deepspeed.profiling.flops_profiler import get_model_profile
pt_models = {
'resnet18': models.resnet18,
'resnet50': models.resnet50,
'alexnet': models.alexnet,
'vgg16': models.vgg16,
'squeezenet': models.squeezenet1_0,
'densenet': models.densenet161,
'inception': models.inception_v3
}
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='flops-profiler example script')
parser.add_argument('--device',
type=int,
default=0,
help='Device to store the model.')
parser.add_argument('--model',
choices=list(pt_models.keys()),
type=str,
default='resnet18')
args = parser.parse_args()
model = pt_models[args.model]()
if torch.cuda.is_available():
model.cuda(device=args.device)
batch_size = 256
macs, params, steps = get_model_profile(model, # the PyTorch model to be profiled
input_res=(batch_size, 3, 224, 224), # input shape or input to the input_constructor
input_constructor=None, # If specified, the constructor is applied to input_res and the constructor output is used as the input to the model
print_profile=True, # whether to print the model graph with the profile annotated. Defaults to True
print_aggregated_profile=True, # whether to print the aggregated profile for top modules. Defaults to True
module_depth=-1, # the depth into the nested modules. Defaults to -1 (the inner most modules)
top_modules=3, # the number of top modules to print aggregated profile
warm_up=10, # the number of warm-up steps before measuring the time of each module. Defaults to 5
num_steps=10, # the number of steps to profile. Defaults to 10
as_strings=True, # whether to print the output as strings (e.g. 1k). Defaults to True
ignore_modules=None) # the list of modules to ignore during profiling. Defaults to None
print("{:<30} {:<8}".format("Batch size: ", batch_size))
print('{:<30} {:<8}'.format('Number of MACs: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
print('{:<30} {:<8}'.format('Number of steps profiled: ', steps))
# Output:
# Number of MACs: 466.48 GMACs
# Number of parameters: 11.69 M
```
##### Bert model example:
```python
from functools import partial
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.profiling.flops_profiler import get_model_profile
def bert_input_constructor(input_shape, tokenizer):
inp_seq = ""
for _ in range(input_shape[1] - 2): # there are two special tokens [CLS] and [SEP]
inp_seq += tokenizer.pad_token # let's use pad token to form a fake
# sequence for subsequent flops calculation
inputs = tokenizer([inp_seq] * input_shape[0],
padding=True,
truncation=True,
return_tensors="pt")
labels = torch.tensor([1] * input_shape[0])
# Batch size input_shape[0], sequence length input_shape[1]
inputs = dict(inputs)
inputs.update({"labels": labels})
return inputs
if __name__ == '__main__':
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
macs, params, steps = get_model_profile(
model,
(2, 128),
input_constructor=partial(bert_input_constructor,
tokenizer=bert_tokenizer),
print_profile=True,
print_aggregated_profile=True,
)
print("{:<30} {:<8}".format("Number of multiply-adds: ", macs))
print("{:<30} {:<8}".format("Number of parameters: ", params))
print("{:<30} {:<8}".format("Number of steps profiled: ", steps))
# Output:
# Number of multiply-adds: 21.74 GMACs
# Number of parameters: 109.48 M
```
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
# Per-forward-call accumulator of (functional_name, flops) pairs; filled by the
# wrappers installed in _patch_functionals and drained by the module post-hooks.
module_flop_count = []
# Registry of the original torch.nn.functional functions keyed by name, so
# _reload_functionals can undo the monkey patching.
old_functions = {}
class FlopsProfiler(object):
    """Measures the time, number of estimated flops and parameters of each module in a PyTorch model.

    The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph
    with the measured profile attached to each module. It shows how time, flops and parameters
    are spent in the model and which modules or layers could be the bottleneck. It also outputs
    the names of the top k modules in terms of aggregated time, flops, and parameters at depth l
    with k and l specified by the user. The output profile is computed for each batch of input.
    If multiple forward passes are specified by the user to capture (in the case where the model
    has different paths or for more accurate timing), the average profile of the multiple batches
    is taken.

    Args:
        model (torch.nn.Module): The PyTorch model to profile.
    """
    def __init__(self, model):
        self.model = model

    def start_profile(self, ignore_list=None):
        """Starts profiling.

        Extra attributes are added recursively to all the modules and the profiled
        torch.nn.functionals are monkey patched.

        Args:
            ignore_list (list, optional): the list of modules to ignore while profiling.
                Defaults to None.
        """
        self.reset_profile()
        _patch_functionals()

        def register_module_hooks(module, ignore_list):
            if ignore_list and type(module) in ignore_list:
                return

            # if computing the flops of a module directly (RNN family — see
            # MODULE_HOOK_MAPPING), no functional patching can capture it.
            if type(module) in MODULE_HOOK_MAPPING:
                module.__flops_handle__ = module.register_forward_hook(
                    MODULE_HOOK_MAPPING[type(module)])
                return

            # if computing the flops of the functionals in a module:
            # the pre-hook clears the global accumulator and counts a step.
            def pre_hook(module, input):
                module_flop_count.clear()
                if len(input) > 0:
                    # Can have multiple inputs, getting the first one
                    input = input[0]
                module.__steps__ += 1

            module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)

            # The post-hook drains whatever the patched functionals recorded
            # during this module's forward into the module's own counter.
            def post_hook(module, input, output):
                module.__flops__ += sum([elem[1] for elem in module_flop_count])
                module_flop_count.clear()

            # Only leaf modules get a post-hook, so each functional call is
            # attributed exactly once (containers aggregate via children).
            has_children = len(module._modules.items()) != 0
            if not has_children:
                module.__post_hook_handle__ = module.register_forward_hook(post_hook)

            # Wall-clock timing hooks around each module's forward.
            def start_time_hook(module, input):
                module.__start_time__ = time.time()

            module.__start_time_hook_handle__ = module.register_forward_pre_hook(
                start_time_hook)

            def end_time_hook(module, input, output):
                module.__duration__ += time.time() - module.__start_time__

            module.__end_time_hook_handle__ = module.register_forward_hook(end_time_hook)

        self.model.apply(partial(register_module_hooks, ignore_list=ignore_list))

    def end_profile(self):
        """Ends profiling.

        Added attributes and handles are removed recursively on all the modules and the
        torch.nn.functionals are restored.
        """
        def remove_profile_attrs(module):
            if hasattr(module, "__steps__"):
                del module.__steps__
            if hasattr(module, "__flops__"):
                del module.__flops__
            if hasattr(module, "__params__"):
                del module.__params__
            if hasattr(module, "__start_time__"):
                del module.__start_time__
            if hasattr(module, "__duration__"):
                del module.__duration__
            # Hook handles must be .remove()d, not just deleted, to detach the hooks.
            if hasattr(module, "__pre_hook_handle__"):
                module.__pre_hook_handle__.remove()
                del module.__pre_hook_handle__
            if hasattr(module, "__post_hook_handle__"):
                module.__post_hook_handle__.remove()
                del module.__post_hook_handle__
            if hasattr(module, "__flops_handle__"):
                module.__flops_handle__.remove()
                del module.__flops_handle__
            if hasattr(module, "__start_time_hook_handle__"):
                module.__start_time_hook_handle__.remove()
                del module.__start_time_hook_handle__
            if hasattr(module, "__end_time_hook_handle__"):
                module.__end_time_hook_handle__.remove()
                del module.__end_time_hook_handle__

        self.model.apply(remove_profile_attrs)
        _reload_functionals()

    def reset_profile(self):
        """Resets the profiling.

        Adds or resets the extra attributes on every module.
        """
        def add_or_reset_attrs(module):
            module.__flops__ = 0
            # Parameter count only includes trainable (requires_grad) parameters.
            module.__params__ = sum(p.numel() for p in module.parameters()
                                    if p.requires_grad)
            module.__start_time__ = 0
            module.__duration__ = 0
            module.__steps__ = 0

        self.model.apply(add_or_reset_attrs)

    def get_total_flops(self, in_str=False):
        """Returns the total flops of the model, averaged over the profiled steps.

        Args:
            in_str (bool, optional): whether to output the flops in string. Defaults to False.
        """
        if self.get_total_steps() == 0:
            return 0
        # NOTE(review): `sum` shadows the builtin within this method.
        sum = 0
        for module in self.model.modules():
            sum += module.__flops__
        total_flops = sum / self.get_total_steps()
        return flops_to_string(total_flops) if in_str else total_flops

    def get_total_duration(self, in_str=False):
        """Returns the total duration of the model forward pass, averaged over the profiled steps.

        Args:
            in_str (bool, optional): whether to output the duration in string. Defaults to False.
        """
        if self.get_total_steps() == 0:
            return 0
        total_duration = self.model.__duration__ / self.get_total_steps()
        return duration_to_string(total_duration) if in_str else total_duration

    def get_total_params(self, in_str=False):
        """Returns the total (trainable) parameters of the model.

        Args:
            in_str (bool, optional): whether to output the parameters in string. Defaults to False.
        """
        return params_to_string(
            self.model.__params__) if in_str else self.model.__params__

    def get_total_steps(self):
        """Returns the total number of steps (or input batches) profiled.
        """
        def get_steps(module):
            # A container that never recorded a step derives (and caches) its
            # count from its children.
            if module.__steps__ == 0:
                sum = 0
                for m in module.children():
                    sum += get_steps(m)
                module.__steps__ = sum
            return module.__steps__

        total_steps = get_steps(self.model)
        if total_steps == 0:
            print("no step is profiled")
        return total_steps

    def print_model_profile(self):
        """Prints the model graph with the measured profile attached to each module.
        """
        total_flops = self.get_total_flops()
        total_duration = self.get_total_duration()
        total_params = self.get_total_params()
        total_steps = self.get_total_steps()

        def accumulate_flops(module):
            # Leaf modules hold their own counted flops; containers sum children.
            has_children = len(module._modules.items()) != 0
            if not has_children:
                return module.__flops__
            else:
                sum = 0
                for m in module.children():
                    sum += m.accumulate_flops()
                return sum

        def flops_repr(module):
            # Builds the per-module annotation shown in the printed model graph.
            params = module.__params__
            flops = 0 if total_steps == 0 else module.accumulate_flops() / total_steps
            items = [
                params_to_string(params),
                "{:.2%} Params".format(params / total_params),
                flops_to_string(flops),
                "{:.2%} MACs".format(0.0 if total_flops == 0 else flops / total_flops),
            ]
            duration = 0 if total_steps == 0 else module.__duration__ / total_steps
            items.append(duration_to_string(duration))
            items.append("{:.2%} time".format(0.0 if total_duration == 0 else duration /
                                              total_duration))
            # flops = 2 * MACs
            items.append(("{:.2} TFLOPS".format(0.0 if duration == 0 else 2 * flops /
                                                duration / 10**12)))
            items.append(str(module.__steps__))
            items.append(module.original_extra_repr())
            return ", ".join(items)

        def add_extra_repr(module):
            # Bind the helpers as methods and swap extra_repr so that
            # print(self.model) emits the profile alongside the module repr.
            module.accumulate_flops = accumulate_flops.__get__(module)
            flops_extra_repr = flops_repr.__get__(module)
            if module.extra_repr != flops_extra_repr:
                module.original_extra_repr = module.extra_repr
                module.extra_repr = flops_extra_repr
                assert module.extra_repr != module.original_extra_repr

        def del_extra_repr(module):
            # Restore the original extra_repr and drop the bound helpers.
            if hasattr(module, "original_extra_repr"):
                module.extra_repr = module.original_extra_repr
                del module.original_extra_repr
            if hasattr(module, "accumulate_flops"):
                del module.accumulate_flops

        self.model.apply(add_extra_repr)
        print(self.model)
        self.model.apply(del_extra_repr)

    def print_model_aggregated_profile(self, module_depth=-1, top_modules=3):
        """Prints the names of the top top_modules modules in terms of aggregated time, flops,
        and parameters at depth module_depth.

        Args:
            module_depth (int, optional): the depth of the modules to show. Defaults to -1
                (the innermost modules).
            top_modules (int, optional): the number of top modules to show. Defaults to 3.
        """
        info = {}
        total_steps = self.get_total_steps()
        if total_steps == 0:
            return
        if not hasattr(self.model, "__flops__"):
            print(
                "no __flops__ attribute in the model, call this function after start_profile and before end_profile"
            )
            return

        def walk_module(module, curr_depth, info):
            # Accumulates [flops, params, time] per module class name at each depth.
            if curr_depth not in info:
                info[curr_depth] = {}
            if module.__class__.__name__ not in info[curr_depth]:
                info[curr_depth][module.__class__.__name__] = [
                    0,
                    0,
                    0,
                ]  # flops, params, time
            info[curr_depth][module.__class__.__name__][0] += module.__flops__
            info[curr_depth][module.__class__.__name__][1] += module.__params__
            info[curr_depth][module.__class__.__name__][2] += (module.__duration__)
            has_children = len(module._modules.items()) != 0
            if has_children:
                for child in module.children():
                    walk_module(child, curr_depth + 1, info)

        walk_module(self.model, 0, info)

        depth = module_depth
        if module_depth == -1:
            depth = len(info) - 1

        num_items = min(top_modules, len(info[depth]))

        # Flops and time are averaged per step; params are a static count.
        sort_flops = {
            k: flops_to_string(v[0] / total_steps)
            for k,
            v in sorted(info[depth].items(),
                        key=lambda item: item[1][0],
                        reverse=True)[:num_items]
        }
        sort_params = {
            k: params_to_string(v[1])
            for k,
            v in sorted(info[depth].items(),
                        key=lambda item: item[1][1],
                        reverse=True)[:num_items]
        }
        sort_time = {
            k: duration_to_string(v[2] / total_steps)
            for k,
            v in sorted(info[depth].items(),
                        key=lambda item: item[1][2],
                        reverse=True)[:num_items]
        }

        print(f"Top {num_items} modules in flops at depth {depth} are {sort_flops}")
        print(f"Top {num_items} modules in params at depth {depth} are {sort_params}")
        print(f"Top {num_items} modules in time at depth {depth} are {sort_time}")
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
def _linear_flops_compute(input, weight, bias=None):
out_features = weight.shape[0]
return torch.numel(input) * out_features
def _relu_flops_compute(input, inplace=False):
return torch.numel(input)
def _pool_flops_compute(
input,
kernel_size,
stride=None,
padding=0,
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
return torch.numel(input)
def _conv_flops_compute(input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1):
assert weight.shape[1] * groups == input.shape[1]
batch_size = input.shape[0]
in_channels = input.shape[1]
out_channels = weight.shape[0]
kernel_dims = list(weight.shape[-2:])
input_dims = list(input.shape[2:])
paddings = padding if type(padding) is tuple else (padding, padding)
strides = stride if type(stride) is tuple else (stride, stride)
dilations = dilation if type(dilation) is tuple else (dilation, dilation)
output_dims = [0, 0]
output_dims[0] = (input_dims[0] + 2 * paddings[0] -
(dilations[0] * (kernel_dims[0] - 1) + 1)) // strides[0] + 1
output_dims[1] = (input_dims[1] + 2 * paddings[1] -
(dilations[1] * (kernel_dims[1] - 1) + 1)) // strides[1] + 1
filters_per_channel = out_channels // groups
conv_per_position_flops = int(_prod(kernel_dims)) * in_channels * filters_per_channel
active_elements_count = batch_size * int(_prod(output_dims))
overall_conv_flops = conv_per_position_flops * active_elements_count
bias_flops = 0
if bias is not None:
bias_flops = out_channels * active_elements_count
overall_flops = overall_conv_flops + bias_flops
return int(overall_flops)
def _conv_trans_flops_compute(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    # Estimates the MACs of F.conv_transpose1d/2d/3d; one kernel application
    # per *input* element (hence active_elements_count uses input_dims).
    # NOTE(review): for transposed convolutions the weight layout is
    # (in_channels, out_channels // groups, *kernel_dims), so weight.shape[0]
    # below is actually the input channel count — confirm whether
    # out_channels should be weight.shape[1] * groups instead.
    # NOTE(review): padding/stride/dilation handling assumes exactly two
    # spatial dimensions; conv_transpose1d/3d inputs would index out of range
    # or be under-counted — verify against callers.
    batch_size = input.shape[0]
    in_channels = input.shape[1]
    out_channels = weight.shape[0]
    kernel_dims = list(weight.shape[-2:])
    input_dims = list(input.shape[2:])

    paddings = padding if type(padding) is tuple else (padding, padding)
    strides = stride if type(stride) is tuple else (stride, stride)
    dilations = dilation if type(dilation) is tuple else (dilation, dilation)

    # output_dims is only used for the bias term below.
    output_dims = [0, 0]
    output_dims[0] = (input_dims[0] + 2 * paddings[0] -
                      (dilations[0] * (kernel_dims[0] - 1) + 1)) // strides[0] + 1
    output_dims[1] = (input_dims[1] + 2 * paddings[1] -
                      (dilations[1] * (kernel_dims[1] - 1) + 1)) // strides[1] + 1

    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(_prod(kernel_dims)) * in_channels * filters_per_channel
    active_elements_count = batch_size * int(_prod(input_dims))

    overall_conv_flops = conv_per_position_flops * active_elements_count

    bias_flops = 0
    if bias is not None:
        # One add per output element.
        bias_flops = out_channels * batch_size * int(_prod(output_dims))

    overall_flops = overall_conv_flops + bias_flops

    return int(overall_flops)
def _batch_norm_flops_compute(
input,
running_mean,
running_var,
weight=None,
bias=None,
training=False,
momentum=0.1,
eps=1e-05,
):
# assume affine is true
flops = 2 * torch.numel(input)
return flops
def _upsample_flops_compute(input,
size=None,
scale_factor=None,
mode="nearest",
align_corners=None):
if size is not None:
return int(_prod(size))
assert scale_factor is not None
flops = torch.numel(input)
if len(scale_factor) == len(input):
flops * int(_prod(scale_factor))
else:
flops * scale_factor**len(input)
return flops
def _softmax_flops_compute(input, dim=None, _stacklevel=3, dtype=None):
return torch.numel(input)
def _embedding_flops_compute(
input,
weight,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
):
return 0
def _dropout_flops_compute(input, p=0.5, training=True, inplace=False):
return 0
def wrapFunc(func, funcFlopCompute):
    """Wraps *func* so every call records its estimated flops.

    The original function is stashed in the module-level `old_functions`
    registry (keyed by its name) so _reload_functionals can restore it; each
    call appends a (name, flops) pair to the global `module_flop_count` list
    before delegating to the original.
    """
    name = func.__name__
    old_functions[name] = func

    def newFunc(*args, **kwds):
        module_flop_count.append((name, funcFlopCompute(*args, **kwds)))
        return func(*args, **kwds)

    return newFunc
def _patch_functionals():
    """Monkey patches the profiled torch.nn.functional entry points so each
    call records its estimated flops into the global module_flop_count list
    (see wrapFunc). Undone by _reload_functionals."""
    # FC
    F.linear = wrapFunc(F.linear, _linear_flops_compute)

    # convolutions
    F.conv1d = wrapFunc(F.conv1d, _conv_flops_compute)
    F.conv2d = wrapFunc(F.conv2d, _conv_flops_compute)
    F.conv3d = wrapFunc(F.conv3d, _conv_flops_compute)

    # conv transposed
    F.conv_transpose1d = wrapFunc(F.conv_transpose1d, _conv_trans_flops_compute)
    F.conv_transpose2d = wrapFunc(F.conv_transpose2d, _conv_trans_flops_compute)
    F.conv_transpose3d = wrapFunc(F.conv_transpose3d, _conv_trans_flops_compute)

    # activations (all costed as one op per element)
    F.relu = wrapFunc(F.relu, _relu_flops_compute)
    F.prelu = wrapFunc(F.prelu, _relu_flops_compute)
    F.elu = wrapFunc(F.elu, _relu_flops_compute)
    F.leaky_relu = wrapFunc(F.leaky_relu, _relu_flops_compute)
    F.relu6 = wrapFunc(F.relu6, _relu_flops_compute)

    # BatchNorms
    F.batch_norm = wrapFunc(F.batch_norm, _batch_norm_flops_compute)

    # poolings
    F.avg_pool1d = wrapFunc(F.avg_pool1d, _pool_flops_compute)
    F.avg_pool2d = wrapFunc(F.avg_pool2d, _pool_flops_compute)
    F.avg_pool3d = wrapFunc(F.avg_pool3d, _pool_flops_compute)
    F.max_pool1d = wrapFunc(F.max_pool1d, _pool_flops_compute)
    F.max_pool2d = wrapFunc(F.max_pool2d, _pool_flops_compute)
    F.max_pool3d = wrapFunc(F.max_pool3d, _pool_flops_compute)
    F.adaptive_avg_pool1d = wrapFunc(F.adaptive_avg_pool1d, _pool_flops_compute)
    F.adaptive_avg_pool2d = wrapFunc(F.adaptive_avg_pool2d, _pool_flops_compute)
    F.adaptive_avg_pool3d = wrapFunc(F.adaptive_avg_pool3d, _pool_flops_compute)
    F.adaptive_max_pool1d = wrapFunc(F.adaptive_max_pool1d, _pool_flops_compute)
    F.adaptive_max_pool2d = wrapFunc(F.adaptive_max_pool2d, _pool_flops_compute)
    F.adaptive_max_pool3d = wrapFunc(F.adaptive_max_pool3d, _pool_flops_compute)

    # upsample
    F.upsample = wrapFunc(F.upsample, _upsample_flops_compute)
    F.interpolate = wrapFunc(F.interpolate, _upsample_flops_compute)

    # softmax
    F.softmax = wrapFunc(F.softmax, _softmax_flops_compute)

    # embedding
    F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)
def _reload_functionals():
    """Restores the original torch.nn.functional functions patched by
    _patch_functionals.

    torch.nn.functional does not support importlib.reload(), so each patched
    attribute is reassigned by name from the `old_functions` registry.
    """
    patched_names = [
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "relu",
        "prelu",
        "elu",
        "leaky_relu",
        "relu6",
        "batch_norm",
        "avg_pool1d",
        "avg_pool2d",
        "avg_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "adaptive_avg_pool1d",
        "adaptive_avg_pool2d",
        "adaptive_avg_pool3d",
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
        "upsample",
        "interpolate",
        "softmax",
        "embedding",
    ]
    for name in patched_names:
        setattr(F, name, old_functions[name])
def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
# matrix matrix mult ih state and internal state
flops += w_ih.shape[0] * w_ih.shape[1]
# matrix matrix mult hh state and internal state
flops += w_hh.shape[0] * w_hh.shape[1]
if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
# add both operations
flops += rnn_module.hidden_size
elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
# hadamard of r
flops += rnn_module.hidden_size
# adding operations from both states
flops += rnn_module.hidden_size * 3
# last two hadamard _product and add
flops += rnn_module.hidden_size * 3
elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
# adding operations from both states
flops += rnn_module.hidden_size * 4
# two hadamard _product and add for C state
flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
# final hadamard
flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
return flops
def _rnn_forward_hook(rnn_module, input, output):
    # Forward hook that estimates the flops of a full nn.RNN/GRU/LSTM forward
    # and accumulates them onto rnn_module.__flops__.
    flops = 0
    # input is a tuple containing a sequence to process and (optionally) hidden state
    inp = input[0]
    # NOTE(review): this reads batch from dim 0 and sequence length from dim 1,
    # i.e. it assumes batch_first=True; nn.RNN defaults to seq-first — confirm.
    batch_size = inp.shape[0]
    seq_length = inp.shape[1]
    num_layers = rnn_module.num_layers

    for i in range(num_layers):
        w_ih = rnn_module.__getattr__("weight_ih_l" + str(i))
        w_hh = rnn_module.__getattr__("weight_hh_l" + str(i))
        if i == 0:
            input_size = rnn_module.input_size
        else:
            input_size = rnn_module.hidden_size
        flops = _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size)
        if rnn_module.bias:
            b_ih = rnn_module.__getattr__("bias_ih_l" + str(i))
            b_hh = rnn_module.__getattr__("bias_hh_l" + str(i))
            flops += b_ih.shape[0] + b_hh.shape[0]

    # Per-timestep cost is repeated for every element of every sequence.
    flops *= batch_size
    flops *= seq_length
    if rnn_module.bidirectional:
        flops *= 2
    rnn_module.__flops__ += int(flops)
def _rnn_cell_forward_hook(rnn_cell_module, input, output):
    # Forward hook that estimates the flops of a single RNN/GRU/LSTM cell step
    # and accumulates them onto rnn_cell_module.__flops__.
    flops = 0
    inp = input[0]
    batch_size = inp.shape[0]
    w_ih = rnn_cell_module.__getattr__("weight_ih")
    w_hh = rnn_cell_module.__getattr__("weight_hh")
    input_size = inp.shape[1]
    flops = _rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size)
    if rnn_cell_module.bias:
        b_ih = rnn_cell_module.__getattr__("bias_ih")
        b_hh = rnn_cell_module.__getattr__("bias_hh")
        flops += b_ih.shape[0] + b_hh.shape[0]

    # A cell processes one timestep, so only the batch dimension multiplies.
    flops *= batch_size
    rnn_cell_module.__flops__ += int(flops)
# Modules whose flops are computed by a dedicated forward hook rather than via
# the patched torch.nn.functional calls (recurrent layers do not go through
# the profiled functionals).
MODULE_HOOK_MAPPING = {
    # RNN
    nn.RNN: _rnn_forward_hook,
    nn.GRU: _rnn_forward_hook,
    nn.LSTM: _rnn_forward_hook,
    nn.RNNCell: _rnn_cell_forward_hook,
    nn.LSTMCell: _rnn_cell_forward_hook,
    nn.GRUCell: _rnn_cell_forward_hook,
}
def flops_to_string(flops, units=None, precision=2):
    """Format a MAC count as a human-readable string.

    When *units* is None the largest fitting magnitude (G/M/K) is chosen
    automatically; otherwise the requested unit string is used directly.
    """
    if units is None:
        # auto-select: largest unit whose integer threshold is reached
        for unit, threshold in (("GMACs", 10**9),
                                ("MMACs", 10**6),
                                ("KMACs", 10**3)):
            if flops // threshold > 0:
                return "{} {}".format(round(flops / float(threshold), precision),
                                      unit)
        return "{} MACs".format(flops)

    scale = {"GMACs": 10.0**9, "MMACs": 10.0**6, "KMACs": 10.0**3}
    if units in scale:
        return "{} {}".format(round(flops / scale[units], precision), units)
    return "{} MACs".format(flops)
def params_to_string(params_num, units=None, precision=2):
    """Format a parameter count as a human-readable string.

    Args:
        params_num (int): number of parameters.
        units (str, optional): force "M" or "K"; None auto-selects. Defaults to None.
        precision (int, optional): decimal places in the rounded value. Defaults to 2.

    Returns:
        str: e.g. "61.71 k", "1.23 M", or the bare number below 1000.
    """
    if units is None:
        if params_num // 10**6 > 0:
            # bug fix: honor `precision` (was hard-coded to 2 in this branch)
            return str(round(params_num / 10.0**6, precision)) + " M"
        elif params_num // 10**3 > 0:
            return str(round(params_num / 10.0**3, precision)) + " k"
        else:
            return str(params_num)
    else:
        if units == "M":
            return str(round(params_num / 10.0**6, precision)) + " " + units
        elif units == "K":
            return str(round(params_num / 10.0**3, precision)) + " " + units
        else:
            return str(params_num)
def duration_to_string(duration, units=None, precision=2):
    """Format a duration given in seconds as a human-readable string."""
    if units == "us":
        return str(round(duration * 10.0**6, precision)) + " us"
    if units == "ms":
        return str(round(duration * 10.0**3, precision)) + " ms"
    if units is not None:
        # any other explicit unit falls back to seconds
        return str(round(duration, precision)) + " s"
    # auto-select: seconds, then milliseconds, then microseconds
    if duration > 1:
        return str(round(duration, precision)) + " s"
    if duration * 10**3 > 1:
        return str(round(duration * 10**3, precision)) + " ms"
    if duration * 10**6 > 1:
        return str(round(duration * 10**6, precision)) + " us"
    return str(duration)
def get_model_profile(
model,
input_res,
input_constructor=None,
print_profile=True,
print_aggregated_profile=True,
module_depth=-1,
top_modules=3,
warm_up=5,
num_steps=10,
as_strings=True,
ignore_modules=None,
):
"""Returns the total flops, parameters, and profiled steps of a model.
Args:
model ([torch.nn.Module]): the PyTorch model to be profiled.
input_res (list): input shape or input to the input_constructor
input_constructor (func, optional): input constructor. If specified, the constructor is applied to input_res and the constructor output is used as the input to the model. Defaults to None.
print_profile (bool, optional): whether to print the model graph with the profile annotated. Defaults to True.
print_aggregated_profile (bool, optional): whether to print the aggregated profile for top modules. Defaults to True.
module_depth (int, optional): the depth into the nested modules. Defaults to -1 (the inner most modules).
top_modules (int, optional): the number of top modules to print in the aggregated profile. Defaults to 3.
warm_up (int, optional): the number of warm-up steps before measuring the time of each module. Defaults to 5.
num_steps (int, optional): the number of steps to profile. Defaults to 10.
as_strings (bool, optional): whether to print the output as strings. Defaults to True.
ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.
"""
assert type(input_res) is tuple
assert len(input_res) >= 1
assert isinstance(model, nn.Module)
prof = FlopsProfiler(model)
model.eval()
for _ in range(warm_up):
if input_constructor:
input = input_constructor(input_res)
_ = model(**input)
else:
try:
batch = torch.ones(()).new_empty(
(*input_res),
dtype=next(model.parameters()).dtype,
device=next(model.parameters()).device,
)
except StopIteration:
batch = torch.ones(()).new_empty((*input_res))
_ = model(batch)
prof.start_profile(ignore_list=ignore_modules)
for _ in range(num_steps):
if input_constructor:
input = input_constructor(input_res)
_ = model(**input)
else:
try:
batch = torch.ones(()).new_empty(
(*input_res),
dtype=next(model.parameters()).dtype,
device=next(model.parameters()).device,
)
except StopIteration:
batch = torch.ones(()).new_empty((*input_res))
_ = model(batch)
flops = prof.get_total_flops()
params = prof.get_total_params()
steps = prof.get_total_steps()
if print_profile:
prof.print_model_profile()
if print_aggregated_profile:
prof.print_model_aggregated_profile(module_depth=module_depth,
top_modules=top_modules)
prof.end_profile()
if as_strings:
return flops_to_string(flops), params_to_string(params), steps
return flops, params, steps
...@@ -22,6 +22,8 @@ from ..elasticity.config import ElasticityConfigError ...@@ -22,6 +22,8 @@ from ..elasticity.config import ElasticityConfigError
from ..elasticity.constants import ELASTICITY, IGNORE_NON_ELASTIC_BATCH_INFO, \ from ..elasticity.constants import ELASTICITY, IGNORE_NON_ELASTIC_BATCH_INFO, \
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT
from ..profiling.config import DeepSpeedFlopsProfilerConfig
TENSOR_CORE_ALIGN_SIZE = 8 TENSOR_CORE_ALIGN_SIZE = 8
ADAM_OPTIMIZER = 'adam' ADAM_OPTIMIZER = 'adam'
...@@ -613,6 +615,7 @@ class DeepSpeedConfig(object): ...@@ -613,6 +615,7 @@ class DeepSpeedConfig(object):
self.scheduler_params = get_scheduler_params(param_dict) self.scheduler_params = get_scheduler_params(param_dict)
self.wall_clock_breakdown = get_wall_clock_breakdown(param_dict) self.wall_clock_breakdown = get_wall_clock_breakdown(param_dict)
self.flops_profiler_config = DeepSpeedFlopsProfilerConfig(param_dict)
self.memory_breakdown = get_memory_breakdown(param_dict) self.memory_breakdown = get_memory_breakdown(param_dict)
self.tensorboard_enabled = get_tensorboard_enabled(param_dict) self.tensorboard_enabled = get_tensorboard_enabled(param_dict)
self.tensorboard_output_path = get_tensorboard_output_path(param_dict) self.tensorboard_output_path = get_tensorboard_output_path(param_dict)
......
...@@ -39,6 +39,8 @@ from ..ops.op_builder import UtilsBuilder ...@@ -39,6 +39,8 @@ from ..ops.op_builder import UtilsBuilder
from ..ops.adam import DeepSpeedCPUAdam from ..ops.adam import DeepSpeedCPUAdam
from ..ops.adam import FusedAdam from ..ops.adam import FusedAdam
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
MEMORY_OPT_ALLREDUCE_SIZE = 500000000 MEMORY_OPT_ALLREDUCE_SIZE = 500000000
try: try:
...@@ -265,6 +267,21 @@ class DeepSpeedEngine(Module): ...@@ -265,6 +267,21 @@ class DeepSpeedEngine(Module):
def wall_clock_breakdown(self): def wall_clock_breakdown(self):
return self._config.wall_clock_breakdown return self._config.wall_clock_breakdown
    def flops_profiler_enabled(self):
        """Return True when the flops profiler is enabled in the DeepSpeed config."""
        return self._config.flops_profiler_config.enabled
    def flops_profiler_start_step(self):
        """Return the global step at which flops profiling starts."""
        return self._config.flops_profiler_config.start_step
    def flops_profiler_end_step(self):
        """Return the global step at which flops profiling ends and is reported."""
        return self._config.flops_profiler_config.end_step
    def flops_profiler_module_depth(self):
        """Return the configured module nesting depth for the profile printout."""
        return self._config.flops_profiler_config.module_depth
    def flops_profiler_top_modules(self):
        """Return how many top modules the aggregated profile prints."""
        return self._config.flops_profiler_config.top_modules
def memory_breakdown(self): def memory_breakdown(self):
return self._config.memory_breakdown return self._config.memory_breakdown
...@@ -764,6 +781,30 @@ class DeepSpeedEngine(Module): ...@@ -764,6 +781,30 @@ class DeepSpeedEngine(Module):
*inputs: Variable length input list *inputs: Variable length input list
**kwargs: variable length keyword arguments **kwargs: variable length keyword arguments
""" """
if self.flops_profiler_enabled(
) and self.global_steps == self.flops_profiler_start_step(
) and self.global_rank == 0:
self.flops_profiler = FlopsProfiler(self.module)
self.flops_profiler.start_profile(ignore_list=None)
if self.flops_profiler_enabled(
) and self.global_steps == self.flops_profiler_end_step(
) and self.global_rank == 0:
print('{:<30} {:<8}'.format(
'Number of multiply-adds: ',
self.flops_profiler.get_total_flops(in_str=False)))
print('{:<30} {:<8}'.format(
'Number of parameters: ',
self.flops_profiler.get_total_params(in_str=False)))
print('{:<30} {:<8}'.format('Number of steps profiled: ',
self.flops_profiler.get_total_steps()))
self.flops_profiler.print_model_profile()
self.flops_profiler.print_model_aggregated_profile(
module_depth=self.flops_profiler_module_depth(),
top_modules=self.flops_profiler_top_modules())
self.flops_profiler.flops = self.flops_profiler.get_total_flops()
self.flops_profiler.params = self.flops_profiler.get_total_params()
self.flops_profiler.end_profile()
if self.module.training and self.progressive_layer_drop: if self.module.training and self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state()) kwargs.update(self.progressive_layer_drop.get_state())
......
import torch
import deepspeed
import deepspeed.runtime.utils as ds_utils
from deepspeed.profiling.flops_profiler import FlopsProfiler, get_model_profile
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
from common import distributed_test
def test_flops_profiler_in_ds_trainning(tmpdir):
    """End-to-end check: the flops profiler runs inside a DeepSpeed training
    loop and records flops/params between start_step and end_step."""
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001,
            }
        },
        "zero_optimization": {
            "stage": 0
        },
        "fp16": {
            "enabled": True,
        },
        "flops_profiler": {
            "enabled": True,
            "start_step": 2,
            "end_step": 3,
            "module_depth": -1,
            "top_modules": 3,
        },
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10
    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_flops_profiler_in_ds_trainning(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        loader = random_dataloader(model=model,
                                   total_samples=50,
                                   hidden_dim=hidden_dim,
                                   device=model.device,
                                   dtype=torch.half)
        for step, batch in enumerate(loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            # run a few steps past end_step so the profile is finalized
            if step == 3:
                break
        assert model.flops_profiler.flops == 100
        assert model.flops_profiler.params == 110

    _test_flops_profiler_in_ds_trainning(args, model, hidden_dim)
class LeNet5(torch.nn.Module):
    """Classic LeNet-5: three conv/tanh stages (pooled after the first two)
    followed by a two-layer fully-connected classifier."""

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        # (in_channels, out_channels) for each 5x5 conv stage
        conv_channels = [(1, 6), (6, 16), (16, 120)]
        stages = []
        for stage_idx, (c_in, c_out) in enumerate(conv_channels):
            stages.append(
                torch.nn.Conv2d(in_channels=c_in,
                                out_channels=c_out,
                                kernel_size=5,
                                stride=1))
            stages.append(torch.nn.Tanh())
            # no pooling after the last conv stage
            if stage_idx < len(conv_channels) - 1:
                stages.append(torch.nn.AvgPool2d(kernel_size=2))
        self.feature_extractor = torch.nn.Sequential(*stages)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=120, out_features=84),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=84, out_features=n_classes),
        )

    def forward(self, x):
        """Return (logits, class probabilities) for a batch of 1x32x32 images."""
        features = self.feature_extractor(x)
        flat = torch.flatten(features, 1)
        logits = self.classifier(flat)
        probs = torch.nn.functional.softmax(logits, dim=1)
        return logits, probs
def test_flops_profiler_in_inference():
    """Profile LeNet-5 inference and check the reported MACs/params strings."""
    net = LeNet5(10)
    sample = torch.randn(1024, 1, 32, 32)
    macs, params, steps = get_model_profile(
        net,
        tuple(sample.shape),
        print_profile=True,
        print_aggregated_profile=True,
        module_depth=-1,
        top_modules=3,
        warm_up=5,
        num_steps=10,
        as_strings=True,
        ignore_modules=None,
    )
    print(macs, params, steps)
    assert macs == "439.55 MMACs"
    assert params == "61.71 k"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment