Unverified Commit e2dfe0d1 authored by Cheng Li, committed by GitHub

Add flops profiler tutorial (#682)

* work on flops profiler tutorial

* update flops profiler tutorial

* add flops profiler tutorial and fix names

* work on flops profiler tutorial

* update flops profiler tutorial

* add flops profiler tutorial and fix names

* fix tailing ws

* fix names

* remove multistep profiling and update docs

* fix cases where functionals and submodules coexist in a parent module, update readme

* fix typo

* always invoke post hook function

* fix module flops sum and update tests

* update tutorial
parent 6ee3b296
...
@@ -15,8 +15,7 @@ class DeepSpeedFlopsProfilerConfig(object):
         super(DeepSpeedFlopsProfilerConfig, self).__init__()
         self.enabled = None
-        self.start_step = None
-        self.end_step = None
+        self.profile_step = None
         self.module_depth = None
         self.top_modules = None
...
@@ -35,13 +34,9 @@ class DeepSpeedFlopsProfilerConfig(object):
                                              FLOPS_PROFILER_ENABLED,
                                              FLOPS_PROFILER_ENABLED_DEFAULT)
-        self.start_step = get_scalar_param(flops_profiler_dict,
-                                           FLOPS_PROFILER_START_STEP,
-                                           FLOPS_PROFILER_START_STEP_DEFAULT)
-        self.end_step = get_scalar_param(flops_profiler_dict,
-                                         FLOPS_PROFILER_END_STEP,
-                                         FLOPS_PROFILER_END_STEP_DEFAULT)
+        self.profile_step = get_scalar_param(flops_profiler_dict,
+                                             FLOPS_PROFILER_PROFILE_STEP,
+                                             FLOPS_PROFILER_PROFILE_STEP_DEFAULT)
         self.module_depth = get_scalar_param(flops_profiler_dict,
                                              FLOPS_PROFILER_MODULE_DEPTH,
...
@@ -50,3 +45,7 @@ class DeepSpeedFlopsProfilerConfig(object):
         self.top_modules = get_scalar_param(flops_profiler_dict,
                                             FLOPS_PROFILER_TOP_MODULES,
                                             FLOPS_PROFILER_TOP_MODULES_DEFAULT)
+        self.detailed = get_scalar_param(flops_profiler_dict,
+                                         FLOPS_PROFILER_DETAILED,
+                                         FLOPS_PROFILER_DETAILED_DEFAULT)
...
@@ -12,11 +12,11 @@ FLOPS_PROFILER_FORMAT = '''
 flops profiler should be enabled as:
 "session_params": {
   "flops_profiler": {
-    "enalbe": [true|false],
-    "start_step": 5,
-    "end_step": 6,
+    "enabled": true,
+    "profile_step": 1,
     "module_depth": -1,
     "top_modules": 3,
+    "detailed": true,
     }
 }
 '''
...
@@ -26,14 +26,14 @@ FLOPS_PROFILER = "flops_profiler"
 FLOPS_PROFILER_ENABLED = "enabled"
 FLOPS_PROFILER_ENABLED_DEFAULT = False
-FLOPS_PROFILER_START_STEP = "start_step"
-FLOPS_PROFILER_START_STEP_DEFAULT = 5
-FLOPS_PROFILER_END_STEP = "end_step"
-FLOPS_PROFILER_END_STEP_DEFAULT = FLOPS_PROFILER_START_STEP_DEFAULT + 1
+FLOPS_PROFILER_PROFILE_STEP = "profile_step"
+FLOPS_PROFILER_PROFILE_STEP_DEFAULT = 1
 FLOPS_PROFILER_MODULE_DEPTH = "module_depth"
 FLOPS_PROFILER_MODULE_DEPTH_DEFAULT = -1
 FLOPS_PROFILER_TOP_MODULES = "top_modules"
 FLOPS_PROFILER_TOP_MODULES_DEFAULT = 3
+FLOPS_PROFILER_DETAILED = "detailed"
+FLOPS_PROFILER_DETAILED_DEFAULT = True
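For reference, `get_scalar_param` used above simply reads a key from the config dict and falls back to a default. A minimal stand-in sketch (the real helper lives in DeepSpeed's config utilities; this re-implementation is for illustration only):

```python
# Illustrative stand-in for DeepSpeed's get_scalar_param helper.
def get_scalar_param(param_dict, param_name, param_default):
    return param_dict.get(param_name, param_default)

flops_profiler_dict = {"enabled": True, "profile_step": 1}
print(get_scalar_param(flops_profiler_dict, "profile_step", 1))  # 1
print(get_scalar_param(flops_profiler_dict, "detailed", True))   # True (the default)
```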
# DeepSpeed Flops Profiler

> Measures the parameters, latency, and floating point operations of your model.

- [Overview](#overview)
- [Supported Models](#supported-models)
- [Multi-GPU, Multi-node Runs](#multi-gpu-multi-node-runs)
- [Usage](#usage)

## Overview

The DeepSpeed flops profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module.
It shows the parameters, latency, and number of floating point operations of the modules within the model to identify potential bottlenecks.
It also outputs the names of the top `k` modules in terms of aggregated time, flops, and number of parameters at depth `l`, with `k` and `l` specified by the user.
The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package.

The output profile is computed for each batch of input and printed to `stdout`. For each module, the measured profile is annotated after its name, listed in the order of: number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency of the module, percentage of the total latency, and floating point operations per second (FLOPS). Note that the profiler estimates the number of floating point operations as `2 * MACs` (each MAC operation is counted as two floating point operations).

Below is an example output for LeNet5 with batch size 1024:
```shell
-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step: 1
Number of parameters: 61.71 k
Number of multiply-accumulate operations (MACs): 439.56 M
Number of floating point operations ( = 2 * MACs): 879.12 M
Latency: 25.7 ms
Floating point operations per second (FLOPS): 34.2 GFLOPS

----------------------------- Aggregated Profile -----------------------------
Top 3 modules in MACs at depth 2 are {'Conv2d': '421.91 MMACs', 'Linear': '11.18 MMACs', 'AvgPool2d': '6.46 MMACs'}
Top 3 modules in params at depth 2 are {'Conv2d': '50.69 k', 'Linear': '11.01 k', 'Tanh': '0'}
Top 3 modules in latency at depth 2 are {'Conv2d': '11.37 ms', 'Linear': '5.27 ms', 'AvgPool2d': '5.02 ms'}

------------------------------ Detailed Profile ------------------------------
Each module profile is listed after its name in the following order:
number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency).
Note:
1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs (or latency) and the sum of its submodules'.
2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.

LeNet5(
  61.71 k, 100.00% Params, 439.56 MMACs, 100.00% MACs, 25.7 ms, 100.00% latency, 34.2 GFLOPS,
  (feature_extractor): Sequential(
    50.69 k, 82.15% Params, 428.37 MMACs, 97.45% MACs, 20.12 ms, 78.27% latency, 42.59 GFLOPS,
    (0): Conv2d(156, 0.25% Params, 125.24 MMACs, 28.49% MACs, 9.8 ms, 38.12% latency, 25.56 GFLOPS, 1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 2.85 ms, 11.08% latency, 0.0 FLOPS, )
    (2): AvgPool2d(0, 0.00% Params, 4.82 MMACs, 1.10% MACs, 4.01 ms, 15.59% latency, 2.4 GFLOPS, kernel_size=2, stride=2, padding=0)
    (3): Conv2d(2.42 k, 3.92% Params, 247.4 MMACs, 56.28% MACs, 924.83 us, 3.60% latency, 535.02 GFLOPS, 6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 672.1 us, 2.62% latency, 0.0 FLOPS, )
    (5): AvgPool2d(0, 0.00% Params, 1.64 MMACs, 0.37% MACs, 1.01 ms, 3.95% latency, 3.23 GFLOPS, kernel_size=2, stride=2, padding=0)
    (6): Conv2d(48.12 k, 77.98% Params, 49.27 MMACs, 11.21% MACs, 647.31 us, 2.52% latency, 152.25 GFLOPS, 16, 120, kernel_size=(5, 5), stride=(1, 1))
    (7): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 82.02 us, 0.32% latency, 0.0 FLOPS, )
  )
  (classifier): Sequential(
    11.01 k, 17.85% Params, 11.18 MMACs, 2.54% MACs, 5.41 ms, 21.06% latency, 4.13 GFLOPS,
    (0): Linear(10.16 k, 16.47% Params, 10.32 MMACs, 2.35% MACs, 2.47 ms, 9.60% latency, 8.37 GFLOPS, in_features=120, out_features=84, bias=True)
    (1): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 90.12 us, 0.35% latency, 0.0 FLOPS, )
    (2): Linear(850, 1.38% Params, 860.16 KMACs, 0.20% MACs, 2.8 ms, 10.91% latency, 613.62 MFLOPS, in_features=84, out_features=10, bias=True)
  )
)
------------------------------------------------------------------------------
```
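The MAC counts above can be checked by hand from the layer shapes. A short sketch using the standard MAC formulas (the 32x32 LeNet input size is assumed):

```python
batch_size = 1024

# classifier (0): Linear(in_features=120, out_features=84)
linear_macs = batch_size * 120 * 84
print(linear_macs)  # 10321920 -> "10.32 MMACs", as reported above

# feature_extractor (0): Conv2d(1, 6, kernel_size=(5, 5)) on 32x32 inputs
# produces a 28x28 output map; each output element costs 5*5*1 MACs, plus 1 for the bias.
conv_macs = batch_size * 6 * 28 * 28 * (5 * 5 * 1 + 1)
print(conv_macs)  # 125239296 -> "125.24 MMACs", as reported above
```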
## Supported Models

The flops estimation is partly inspired by [ptflops](https://github.com/sovrasov/flops-counter.pytorch), with the major difference being that the DeepSpeed flops profiler captures `torch.nn.functional` calls invoked within a module to estimate the flops. Thus the DeepSpeed flops profiler allows for customized modules in the model, e.g., `ParallelTransformerLayer`, `ParallelSelfAttention`, `RowParallelLinear`, etc. in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). This is in contrast to tools that profile at the `torch.nn.Module` level, such as ptflops, which require users to write customized flops calculation functions for each customized module. Finally, the DeepSpeed flops profiler also supports flops computation at the module level (for RNNs).

## Multi-GPU, Multi-node Runs

For models running on multiple GPUs or nodes, only the model parallelism (e.g. `--model-parallel-size` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)) affects the number of flops and parameters profiled, i.e.,
`model_parallel_size * flops = total_flops` and `model_parallel_size * parameters = total_parameters`. The number of GPUs or nodes does not affect the output profile.
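For example (a hypothetical back-of-envelope calculation, not profiler output):

```python
# With 2-way model parallelism, each rank profiles only its own shard;
# scale by the model-parallel degree to recover whole-model numbers.
model_parallel_size = 2
flops = 150e9       # hypothetical MACs reported on one rank
parameters = 20e6   # hypothetical parameters reported on one rank

total_flops = model_parallel_size * flops            # 300 GMACs for the whole model
total_parameters = model_parallel_size * parameters  # 40 M for the whole model
```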
## Usage

The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package. When using DeepSpeed for model training, the flops profiler can be configured in the `deepspeed_config` file without any user code changes. To use the flops profiler outside of the DeepSpeed runtime, simply install DeepSpeed and import the `flops_profiler` package to use the APIs directly. Examples of each usage are given below.

- [Usage With the DeepSpeed Runtime](#usage-with-the-deepspeed-runtime)
  - [Example: Megatron-LM](#example-megatron-lm)
- [Usage Outside the DeepSpeed Runtime](#usage-outside-the-deepspeed-runtime)
  - [In Model Inference](#in-model-inference)
    - [Example: AlexNet](#example-alexnet)
    - [Example: Bert](#example-bert)
  - [In Model Training Workflow](#in-model-training-workflow)
    - [Example Training Workflow](#example-training-workflow)

### Usage With the DeepSpeed Runtime

When using DeepSpeed for model training, the flops profiler can be configured in the `deepspeed_config` file. No explicit API calls are needed to use the profiler. Refer to [flops profiler](https://www.deepspeed.ai/docs/config-json/#flops-profiler) for details.

#### Example: Megatron-LM

For information on running Megatron-LM with DeepSpeed, please refer to our tutorial [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM).

The flops profiler can be enabled by adding the following field to the `deepspeed_config` file.

```json
{
  "flops_profiler": {
    "enabled": true,
    "profile_step": 1,
    "module_depth": -1,
    "top_modules": 3,
    "detailed": true
  }
}
```
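With this field in place, profiling happens automatically during training at the configured step. A minimal launch sketch (the model and file names are hypothetical; `deepspeed.initialize` is the standard entry point, though the exact keyword for passing a config dict may vary across DeepSpeed versions):

```python
import json

import deepspeed

model = MyModel()  # hypothetical torch.nn.Module

with open("ds_config.json") as f:
    ds_config = json.load(f)  # contains the "flops_profiler" field above

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=ds_config)
# The profile is printed to stdout when training reaches "profile_step".
```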
An example output of a 4-layer Megatron-LM model (`hidden_size = 512`, `num_attention_heads = 16`, `batch_size = 8`, `seq_length = 1024`) is shown below.

```shell
-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step: 1
Number of parameters: 38.89 M
Number of multiply-accumulate operations (MACs): 314.61 G
Number of floating point operations ( = 2 * MACs): 629.21 G
Latency: 33.81 ms
Floating point operations per second (FLOPS): 18.61 TFLOPS

----------------------------- Aggregated Profile -----------------------------
Top 3 modules in MACs at depth 8 are {'ColumnParallelLinear': '60.13 GMACs', 'RowParallelLinear': '42.95 GMACs', 'FusedScaleMaskSoftmax': '536.87 MMACs'}
Top 3 modules in params at depth 8 are {'ColumnParallelLinear': '7.35 M', 'RowParallelLinear': '5.25 M', 'FusedScaleMaskSoftmax': '0'}
Top 3 modules in latency at depth 8 are {'ColumnParallelLinear': '659.23 us', 'RowParallelLinear': '587.94 us', 'FusedScaleMaskSoftmax': '370.98 us'}

------------------------------ Detailed Profile ------------------------------
Each module profile is listed after its name in the following order:
number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency).
Note:
1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs (or latency) and the sum of its submodules'.
2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.

DistributedDataParallel(
  38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.81 ms, 100.00% latency, 18.61 TFLOPS,
  (module): FP16_Module(
    38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.77 ms, 99.89% latency, 18.63 TFLOPS,
    (module): GPT2Model(
      38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.69 ms, 99.66% latency, 18.67 TFLOPS,
      (language_model): TransformerLanguageModel(
        38.89 M, 100.00% Params, 103.62 GMACs, 32.94% MACs, 5.58 ms, 16.51% latency, 37.13 TFLOPS,
        (embedding): Embedding(
          26.28 M, 67.57% Params, 0 MACs, 0.00% MACs, 545.98 us, 1.61% latency, 0.0 FLOPS,
          (word_embeddings): VocabParallelEmbedding(25.76 M, 66.23% Params, 0 MACs, 0.00% MACs, 223.88 us, 0.66% latency, 0.0 FLOPS, )
          (position_embeddings): Embedding(524.29 k, 1.35% Params, 0 MACs, 0.00% MACs, 147.1 us, 0.44% latency, 0.0 FLOPS, 1024, 512)
          (embedding_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 79.39 us, 0.23% latency, 0.0 FLOPS, p=0.1, inplace=False)
        )
        (transformer): ParallelTransformer(
          12.61 M, 32.43% Params, 103.62 GMACs, 32.94% MACs, 5.0 ms, 14.78% latency, 41.49 TFLOPS,
          (layers): ModuleList(
            12.61 M, 32.42% Params, 103.62 GMACs, 32.94% MACs, 4.4 ms, 13.01% latency, 47.13 TFLOPS,
            (0): ParallelTransformerLayer(
              3.15 M, 8.11% Params, 25.9 GMACs, 8.23% MACs, 1.36 ms, 4.02% latency, 38.09 TFLOPS,
              (input_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 92.51 us, 0.27% latency, 0.0 FLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True)
              (attention): ParallelSelfAttention(
                1.05 M, 2.70% Params, 8.72 GMACs, 2.77% MACs, 754.59 us, 2.23% latency, 23.12 TFLOPS,
                (query_key_value): ColumnParallelLinear(787.97 k, 2.03% Params, 6.44 GMACs, 2.05% MACs, 182.87 us, 0.54% latency, 70.46 TFLOPS, )
                (scale_mask_softmax): FusedScaleMaskSoftmax(0, 0.00% Params, 134.22 MMACs, 0.04% MACs, 120.4 us, 0.36% latency, 2.23 TFLOPS, )
                (attention_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 47.45 us, 0.14% latency, 0.0 FLOPS, p=0.1, inplace=False)
                (dense): RowParallelLinear(262.66 k, 0.68% Params, 2.15 GMACs, 0.68% MACs, 81.78 us, 0.24% latency, 52.52 TFLOPS, )
              )
              (post_attention_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 57.22 us, 0.17% latency, 0.0 FLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True)
              (mlp): ParallelMLP(
                2.1 M, 5.40% Params, 17.18 GMACs, 5.46% MACs, 224.83 us, 0.67% latency, 152.83 TFLOPS,
                (dense_h_to_4h): ColumnParallelLinear(1.05 M, 2.70% Params, 8.59 GMACs, 2.73% MACs, 64.13 us, 0.19% latency, 267.87 TFLOPS, )
                (dense_4h_to_h): RowParallelLinear(1.05 M, 2.70% Params, 8.59 GMACs, 2.73% MACs, 90.36 us, 0.27% latency, 190.13 TFLOPS, )
              )
            )
            ...
            (3): ParallelTransformerLayer(...)
          )
          (final_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 52.69 us, 0.16% latency, 0.0 TFLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
)
```
### Usage Outside the DeepSpeed Runtime

The flops profiler can be used as a standalone package outside of the DeepSpeed runtime.
One can simply install DeepSpeed and import the `flops_profiler` package to use the APIs directly.
Refer to [installation of DeepSpeed](https://www.deepspeed.ai/getting-started/#installation) for installing DeepSpeed.
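Installing DeepSpeed installs the profiler as well:

```shell
pip install deepspeed
```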
#### In Model Inference

To profile a trained model in inference, use the `get_model_profile` function.
Examples are given below.

##### Example: AlexNet

The following example shows how to profile AlexNet using the DeepSpeed flops profiler.
```python
import torchvision.models as models
import torch
from deepspeed.profiling.flops_profiler import get_model_profile

with torch.cuda.device(0):
    model = models.alexnet()
    batch_size = 256
    macs, params = get_model_profile(model=model, # the PyTorch model to be profiled
                                     input_res=(batch_size, 3, 224, 224), # input shape or input to the input_constructor
                                     input_constructor=None, # if specified, a constructor taking input_res is used as input to the model
                                     print_profile=True, # prints the model graph with the measured profile attached to each module
                                     detailed=True, # print the detailed profile
                                     module_depth=-1, # depth into the nested modules, with -1 being the innermost modules
                                     top_modules=3, # the number of top modules to print in the aggregated profile
                                     warm_up=10, # the number of warm-up steps before measuring the latency of each module
                                     as_string=True, # print human-readable strings (e.g. 1k) instead of raw numbers (e.g. 1000)
                                     ignore_modules=None) # the list of modules to ignore during profiling
```
print("{:<30} {:<8}".format("Batch size: ", batch_size))
print('{:<30} {:<8}'.format('Number of MACs: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
print('{:<30} {:<8}'.format('Number of steps profiled: ', steps))
# Output:
# Number of MACs: 466.48 GMACs
# Number of parameters: 11.69 M
An example output:
```shell
-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step: 10
Number of parameters: 61.1 M
Number of multiply-accumulate operations (MACs): 183.18 G
Number of floating point operations ( = 2 * MACs): 366.36 G
Latency: 22.13 ms
Floating point operations per second (FLOPS): 16.56 TFLOPS
----------------------------- Aggregated Profile -----------------------------
Top 3 modules in MACs at depth 2 are {'Conv2d': '167.95 GMACs', 'Linear': '15.01 GMACs', 'ReLU': '126.26 MMACs'}
Top 3 modules in params at depth 2 are {'Linear': '58.63 M', 'Conv2d': '2.47 M', 'ReLU': '0'}
Top 3 modules in latency at depth 2 are {'Conv2d': '13.96 ms', 'Linear': '6.23 ms', 'ReLU': '730.75 us'}
------------------------------ Detailed Profile ------------------------------
Each module profile is listed after its name in the following order:
number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency).
Note:
1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs (or latency) and the sum of its submodules'.
2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.
AlexNet(
61.1 M, 100.00% Params, 183.18 GMACs, 100.00% MACs, 22.13 ms, 100.00% latency, 16.56 TFLOPS,
(features): Sequential(
2.47 M, 4.04% Params, 168.17 GMACs, 91.81% MACs, 15.17 ms, 68.57% latency, 22.17 TFLOPS,
(0): Conv2d(23.3 k, 0.04% Params, 18.04 GMACs, 9.85% MACs, 633.0 us, 2.86% latency, 57.0 TFLOPS, 3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): ReLU(0, 0.00% Params, 49.56 MMACs, 0.03% MACs, 163.79 us, 0.74% latency, 605.17 GFLOPS, inplace=True)
(2): MaxPool2d(0, 0.00% Params, 49.56 MMACs, 0.03% MACs, 159.26 us, 0.72% latency, 622.38 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(307.39 k, 0.50% Params, 57.37 GMACs, 31.32% MACs, 6.15 ms, 27.81% latency, 18.64 TFLOPS, 64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): ReLU(0, 0.00% Params, 35.83 MMACs, 0.02% MACs, 185.01 us, 0.84% latency, 387.34 GFLOPS, inplace=True)
(5): MaxPool2d(0, 0.00% Params, 35.83 MMACs, 0.02% MACs, 134.23 us, 0.61% latency, 533.89 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(663.94 k, 1.09% Params, 28.72 GMACs, 15.68% MACs, 389.58 us, 1.76% latency, 147.47 TFLOPS, 192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU(0, 0.00% Params, 16.61 MMACs, 0.01% MACs, 76.53 us, 0.35% latency, 434.15 GFLOPS, inplace=True)
(8): Conv2d(884.99 k, 1.45% Params, 38.29 GMACs, 20.90% MACs, 6.38 ms, 28.82% latency, 12.01 TFLOPS, 384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): ReLU(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 104.43 us, 0.47% latency, 212.12 GFLOPS, inplace=True)
(10): Conv2d(590.08 k, 0.97% Params, 25.53 GMACs, 13.94% MACs, 405.79 us, 1.83% latency, 125.83 TFLOPS, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 65.57 us, 0.30% latency, 337.85 GFLOPS, inplace=True)
(12): MaxPool2d(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 122.07 us, 0.55% latency, 181.46 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(0, 0.00% Params, 2.36 MMACs, 0.00% MACs, 259.4 us, 1.17% latency, 18.19 GFLOPS, output_size=(6, 6))
(classifier): Sequential(
58.63 M, 95.96% Params, 15.01 GMACs, 8.19% MACs, 6.54 ms, 29.54% latency, 4.59 TFLOPS,
(0): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 42.68 us, 0.19% latency, 0.0 FLOPS, p=0.5, inplace=False)
(1): Linear(37.75 M, 61.79% Params, 9.66 GMACs, 5.28% MACs, 301.36 us, 1.36% latency, 64.13 TFLOPS, in_features=9216, out_features=4096, bias=True)
(2): ReLU(0, 0.00% Params, 1.05 MMACs, 0.00% MACs, 79.39 us, 0.36% latency, 26.41 GFLOPS, inplace=True)
(3): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 39.58 us, 0.18% latency, 0.0 FLOPS, p=0.5, inplace=False)
(4): Linear(16.78 M, 27.46% Params, 4.29 GMACs, 2.34% MACs, 234.37 us, 1.06% latency, 36.65 TFLOPS, in_features=4096, out_features=4096, bias=True)
(5): ReLU(0, 0.00% Params, 1.05 MMACs, 0.00% MACs, 56.03 us, 0.25% latency, 37.43 GFLOPS, inplace=True)
(6): Linear(4.1 M, 6.71% Params, 1.05 GMACs, 0.57% MACs, 5.69 ms, 25.72% latency, 368.42 GFLOPS, in_features=4096, out_features=1000, bias=True)
)
)
------------------------------------------------------------------------------
```

##### Example: Bert

```python
from functools import partial

import torch
from transformers import BertForSequenceClassification, BertTokenizer

from deepspeed.profiling.flops_profiler import get_model_profile


def bert_input_constructor(input_shape, tokenizer):
    fake_seq = ""
    for _ in range(input_shape[1] - 2):  # ignore the two special tokens [CLS] and [SEP]
        fake_seq += tokenizer.pad_token
    inputs = tokenizer([fake_seq] * input_shape[0],
                       padding=True,
                       truncation=True,
                       return_tensors="pt")
    labels = torch.tensor([1] * input_shape[0])
    inputs = dict(inputs)
    inputs.update({"labels": labels})
    return inputs


with torch.cuda.device(0):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    batch_size = 4
    seq_len = 128
    enable_profile = True
    if enable_profile:
        macs, params = get_model_profile(
            model,
            (batch_size, seq_len),
            input_constructor=partial(bert_input_constructor,
                                      tokenizer=tokenizer),
            print_profile=True,
            detailed=True,
        )
    else:
        inputs = bert_input_constructor((batch_size, seq_len), tokenizer)
        outputs = model(**inputs)  # unpack the inputs dict into keyword arguments
```
An example output:
```
-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step: 1
Number of parameters: 109.48 M
Number of multiply-accumulate operations (MACs): 43.5 G
Number of floating point operations ( = 2 * MACs): 87.0 G
Latency: 393.7 ms
Floating point operations per second (FLOPS): 220.97 GFLOPS
----------------------------- Aggregated Profile -----------------------------
Top 3 modules in MACs at depth 7 are {'Linear': '14.5 GMACs', 'Dropout': '0 MACs', 'LayerNorm': '0 MACs'}
Top 3 modules in params at depth 7 are {'Linear': '28.35 M', 'LayerNorm': '18.43 k', 'Dropout': '0'}
Top 3 modules in latency at depth 7 are {'Linear': '153.7 ms', 'LayerNorm': '4.74 ms', 'Dropout': '597.95 us'}
------------------------------ Detailed Profile ------------------------------
Each module profile is listed after its name in the following order:
number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency).
Note:
1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs (or latency) and the sum of its submodules'.
2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.
BertForSequenceClassification(
109.48 M, 100.00% Params, 43.5 GMACs, 100.00% MACs, 393.7 ms, 100.00% latency, 220.97 GFLOPS,
(bert): BertModel(
109.48 M, 100.00% Params, 43.5 GMACs, 100.00% MACs, 393.38 ms, 99.92% latency, 221.15 GFLOPS,
(embeddings): BertEmbeddings(
23.84 M, 21.77% Params, 0 MACs, 0.00% MACs, 1.79 ms, 0.45% latency, 0.0 FLOPS,
(word_embeddings): Embedding(23.44 M, 21.41% Params, 0 MACs, 0.00% MACs, 485.18 us, 0.12% latency, 0.0 FLOPS, 30522, 768, padding_idx=0)
(position_embeddings): Embedding(393.22 k, 0.36% Params, 0 MACs, 0.00% MACs, 111.1 us, 0.03% latency, 0.0 FLOPS, 512, 768)
(token_type_embeddings): Embedding(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 215.53 us, 0.05% latency, 0.0 FLOPS, 2, 768)
(LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 386.95 us, 0.10% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 20.27 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False)
)
(encoder): BertEncoder(
85.05 M, 77.69% Params, 43.5 GMACs, 99.99% MACs, 391.03 ms, 99.32% latency, 222.47 GFLOPS,
(layer): ModuleList(
85.05 M, 77.69% Params, 43.5 GMACs, 99.99% MACs, 390.82 ms, 99.27% latency, 222.59 GFLOPS,
(0): BertLayer(
7.09 M, 6.47% Params, 3.62 GMACs, 8.33% MACs, 31.91 ms, 8.10% latency, 227.21 GFLOPS,
(attention): BertAttention(
2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 16.39 ms, 4.16% latency, 147.47 GFLOPS,
(self): BertSelfAttention(
1.77 M, 1.62% Params, 906.76 MMACs, 2.08% MACs, 15.07 ms, 3.83% latency, 120.36 GFLOPS,
(query): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 3.66 ms, 0.93% latency, 164.91 GFLOPS, in_features=768, out_features=768, bias=True)
(key): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 3.72 ms, 0.94% latency, 162.36 GFLOPS, in_features=768, out_features=768, bias=True)
(value): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 4.52 ms, 1.15% latency, 133.65 GFLOPS, in_features=768, out_features=768, bias=True)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 24.08 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False)
)
(output): BertSelfOutput(
592.13 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 1.29 ms, 0.33% latency, 469.21 GFLOPS,
(dense): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 504.26 us, 0.13% latency, 1.2 TFLOPS, in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 437.97 us, 0.11% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 21.93 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 9.57 ms, 2.43% latency, 252.35 GFLOPS,
(dense): Linear(2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 8.75 ms, 2.22% latency, 276.11 GFLOPS, in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 5.77 ms, 1.47% latency, 418.39 GFLOPS,
(dense): Linear(2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 5.13 ms, 1.30% latency, 471.15 GFLOPS, in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 310.9 us, 0.08% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 29.8 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False)
)
)
...
(11): BertLayer(...)
)
)
(pooler): BertPooler(
590.59 k, 0.54% Params, 2.36 MMACs, 0.01% MACs, 337.12 us, 0.09% latency, 14.0 GFLOPS,
(dense): Linear(590.59 k, 0.54% Params, 2.36 MMACs, 0.01% MACs, 173.57 us, 0.04% latency, 27.19 GFLOPS, in_features=768, out_features=768, bias=True)
(activation): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 46.01 us, 0.01% latency, 0.0 FLOPS, )
)
)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 19.55 us, 0.00% latency, 0.0 FLOPS, p=0.1, inplace=False)
(classifier): Linear(1.54 k, 0.00% Params, 6.14 KMACs, 0.00% MACs, 56.51 us, 0.01% latency, 217.47 MFLOPS, in_features=768, out_features=2, bias=True)
)
------------------------------------------------------------------------------
```
#### In Model Training Workflow
To profile the model forward pass in a training workflow, use the `FlopsProfiler` class.
The `FlopsProfiler` class provides the following methods:
* `start_profile()` - starts profiling
* `get_total_flops(as_string=False)` - returns the total number of MACs in the model
* `get_total_params(as_string=False)` - returns the total number of parameters in the model
* `print_model_profile(profile_step=1, module_depth=-1, top_modules=3, detailed=True)` - prints the model profile
* `end_profile()` - ends profiling and cleans up. This should be invoked at the end of the profiling and AFTER `get_total_flops`, `get_total_params` or `print_model_profile`.
##### Example Training Workflow
Below is an example of this usage in a typical training workflow. Note that the flops profiler only captures the forward pass in a training step. The flops of a backward pass can be roughly estimated from that of the forward pass (~2x).
```python
from deepspeed.profiling.flops_profiler import FlopsProfiler

model = Model()
prof = FlopsProfiler(model)

profile_step = 5
print_profile = True

for step, batch in enumerate(data_loader):
    # start profiling at training step "profile_step"
    if step == profile_step:
        prof.start_profile()

    # forward() method
    loss = model(batch)

    # end profiling and print output
    if step == profile_step:  # if using multiple nodes, check global_rank == 0 as well
        flops = prof.get_total_flops(as_string=True)
        params = prof.get_total_params(as_string=True)
        if print_profile:
            prof.print_model_profile(profile_step=profile_step)
        prof.end_profile()

    # runs backpropagation
    loss.backward()

    # weight update
    optimizer.step()
```
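Since the profiler measures only the forward pass, the ~2x note above gives a rough per-step training estimate (a back-of-envelope sketch, not profiler output):

```python
fwd_flops = prof.get_total_flops()          # forward flops of the profiled step
training_flops = fwd_flops + 2 * fwd_flops  # forward + backward (~2x forward), ~3x total
```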
...
@@ -9,9 +9,9 @@ old_functions = {}

 class FlopsProfiler(object):
-    """Measures the time, number of estimated flops and parameters of each module in a PyTorch model.
+    """Measures the latency, number of estimated floating point operations and parameters of each module in a PyTorch model.

-    The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how time, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated time, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input. If multiple forward passes are specified by the user to caputre (in the case where the model have different paths or for more accurate timing), the average profile of the multiple batches is taken.
+    The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how latency, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input.

     Args:
         object (torch.nn.Module): The PyTorch model to profile.
...
@@ -42,20 +42,15 @@ class FlopsProfiler(object):
         # if computing the flops of the functionals in a module

         def pre_hook(module, input):
-            module_flop_count.clear()
-            if len(input) > 0:
-                # Can have multiple inputs, getting the first one
-                input = input[0]
-            module.__steps__ += 1
+            module_flop_count.append([])

         module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)

         def post_hook(module, input, output):
-            module.__flops__ += sum([elem[1] for elem in module_flop_count])
-            module_flop_count.clear()
+            if module_flop_count:
+                module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]])
+                module_flop_count.pop()

-        has_children = len(module._modules.items()) != 0
-        if not has_children:
-            module.__post_hook_handle__ = module.register_forward_hook(post_hook)
+        module.__post_hook_handle__ = module.register_forward_hook(post_hook)
         def start_time_hook(module, input):
...
@@ -77,8 +72,6 @@ class FlopsProfiler(object):
         Added attributes and handles are removed recursively on all the modules and the torch.nn.functionals are restored.
         """
         def remove_profile_attrs(module):
-            if hasattr(module, "__steps__"):
-                del module.__steps__
             if hasattr(module, "__flops__"):
                 del module.__flops__
             if hasattr(module, "__params__"):
...
@@ -117,100 +110,91 @@ class FlopsProfiler(object):
                                      if p.requires_grad)
             module.__start_time__ = 0
             module.__duration__ = 0
-            module.__steps__ = 0

         self.model.apply(add_or_reset_attrs)
-    def get_total_flops(self, in_str=False):
+    def get_total_flops(self, as_string=False):
         """Returns the total flops of the model.

         Args:
-            in_str (bool, optional): whether to output the flops in string. Defaults to False.
+            as_string (bool, optional): whether to output the flops as a string. Defaults to False.
         """
-        if self.get_total_steps() == 0:
-            return 0
-        sum = 0
-        for module in self.model.modules():
-            sum += module.__flops__
-        total_flops = sum / self.get_total_steps()
-        return flops_to_string(total_flops) if in_str else total_flops
+        total_flops = get_module_flops(self.model)
+        return macs_to_string(total_flops) if as_string else total_flops

-    def get_total_duration(self, in_str=False):
+    def get_total_duration(self, as_string=False):
         """Returns the total duration of the model forward pass.

         Args:
-            in_str (bool, optional): whether to output the duration in string. Defaults to False.
+            as_string (bool, optional): whether to output the duration as a string. Defaults to False.
         """
-        if self.get_total_steps() == 0:
-            return 0
-        total_duration = self.model.__duration__ / self.get_total_steps()
-        return duration_to_string(total_duration) if in_str else total_duration
+        total_duration = self.model.__duration__
+        return duration_to_string(total_duration) if as_string else total_duration

-    def get_total_params(self, in_str=False):
+    def get_total_params(self, as_string=False):
         """Returns the total parameters of the model.

         Args:
-            in_str (bool, optional): whether to output the parameters in string. Defaults to False.
+            as_string (bool, optional): whether to output the parameters as a string. Defaults to False.
         """
         return params_to_string(
-            self.model.__params__) if in_str else self.model.__params__
+            self.model.__params__) if as_string else self.model.__params__

-    def get_total_steps(self):
-        """Returns the total number of steps (or input batches) profiled.
-        """
-        def get_steps(module):
-            if module.__steps__ == 0:
-                sum = 0
-                for m in module.children():
-                    sum += get_steps(m)
-                module.__steps__ = sum
-            return module.__steps__
-
-        total_steps = get_steps(self.model)
-        if total_steps == 0:
-            print("no step is profiled")
-        return total_steps
-    def print_model_profile(self):
+    def print_model_profile(self,
+                            profile_step=1,
+                            module_depth=-1,
+                            top_modules=3,
+                            detailed=True):
         """Prints the model graph with the measured profile attached to each module.
         """
         total_flops = self.get_total_flops()
         total_duration = self.get_total_duration()
         total_params = self.get_total_params()
-        total_steps = self.get_total_steps()

-        def accumulate_flops(module):
-            has_children = len(module._modules.items()) != 0
-            if not has_children:
-                return module.__flops__
-            else:
-                sum = 0
-                for m in module.children():
-                    sum += m.accumulate_flops()
-                return sum
+        self.flops = total_flops
+        self.params = total_params
+
+        print(
+            "\n-------------------------- DeepSpeed Flops Profiler --------------------------"
+        )
+        print("Summary of forward pass:")
+        print('{:<30} {:<8}'.format('Profile step: ', profile_step))
+        print('{:<30} {:<8}'.format('Number of parameters: ',
+                                    params_to_string(total_params)))
+        print('{:<30} {:<8}'.format('Number of multiply-accumulate operations (MACs): ',
+                                    num_to_string(total_flops)))
+        print('{:<30} {:<8}'.format(
+            'Number of floating point operations ( = 2 * MACs): ',
+            num_to_string(2 * total_flops)))
+        print('{:<30} {:<8}'.format('Latency: ', duration_to_string(total_duration)))
+        print('{:<30} {:<8}'.format('Floating point operations per second (FLOPS): ',
+                                    flops_to_string(2 * total_flops / total_duration)))

         def flops_repr(module):
             params = module.__params__
-            flops = 0 if total_steps == 0 else module.accumulate_flops() / total_steps
+            flops = get_module_flops(module)
             items = [
                 params_to_string(params),
                 "{:.2%} Params".format(params / total_params),
-                flops_to_string(flops),
+                macs_to_string(flops),
                 "{:.2%} MACs".format(0.0 if total_flops == 0 else flops / total_flops),
             ]
-            duration = 0 if total_steps == 0 else module.__duration__ / total_steps
+            duration = module.__duration__
+            if duration == 0:  # e.g. ModuleList
+                for m in module.children():
+                    duration += m.__duration__
             items.append(duration_to_string(duration))
-            items.append("{:.2%} time".format(0.0 if total_duration == 0 else duration /
-                                              total_duration))
+            items.append(
+                "{:.2%} latency".format(0.0 if total_duration == 0 else duration /
+                                        total_duration))
             # flops = 2 * MACs
-            items.append(("{:.2} TFLOPS".format(0.0 if duration == 0 else 2 * flops /
-                                                duration / 10**12)))
-            items.append(str(module.__steps__))
+            items.append(flops_to_string(0.0 if duration == 0 else 2 * flops / duration))
             items.append(module.original_extra_repr())
             return ", ".join(items)

         def add_extra_repr(module):
-            module.accumulate_flops = accumulate_flops.__get__(module)
             flops_extra_repr = flops_repr.__get__(module)
             if module.extra_repr != flops_extra_repr:
                 module.original_extra_repr = module.extra_repr
...
@@ -221,13 +205,33 @@ class FlopsProfiler(object):
             if hasattr(module, "original_extra_repr"):
                 module.extra_repr = module.original_extra_repr
                 del module.original_extra_repr
-            if hasattr(module, "accumulate_flops"):
-                del module.accumulate_flops

         self.model.apply(add_extra_repr)
+
+        print(
+            "\n----------------------------- Aggregated Profile -----------------------------"
+        )
+        self.print_model_aggregated_profile(module_depth=module_depth,
+                                            top_modules=top_modules)
+
+        if detailed:
+            print(
+                "\n------------------------------ Detailed Profile ------------------------------"
+            )
+            print(
+                "Each module profile is listed after its name in the following order: \nnumber of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency)."
+            )
+            print(
+                "Note: \n1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs (or latency) and the sum of its submodules'.\n2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.\n"
+            )
-        print(self.model)
+            print(self.model)
         self.model.apply(del_extra_repr)
+        print(
+            "------------------------------------------------------------------------------"
+        )
     def print_model_aggregated_profile(self, module_depth=-1, top_modules=3):
         """Prints the names of the top top_modules modules in terms of aggregated time, flops, and parameters at depth module_depth.
...
@@ -236,9 +240,6 @@ class FlopsProfiler(object):
             top_modules (int, optional): the number of top modules to show. Defaults to 3.
         """
         info = {}
-        total_steps = self.get_total_steps()
-        if total_steps == 0:
-            return
         if not hasattr(self.model, "__flops__"):
             print(
                 "no __flops__ attribute in the model, call this function after start_profile and before end_profile"
...
@@ -271,7 +272,7 @@ class FlopsProfiler(object):
         num_items = min(top_modules, len(info[depth]))

         sort_flops = {
-            k: flops_to_string(v[0] / total_steps)
+            k: macs_to_string(v[0])
             for k,
             v in sorted(info[depth].items(),
                         key=lambda item: item[1][0],
...
@@ -285,15 +286,15 @@ class FlopsProfiler(object):
                         reverse=True)[:num_items]
         }
         sort_time = {
-            k: duration_to_string(v[2] / total_steps)
+            k: duration_to_string(v[2])
             for k,
             v in sorted(info[depth].items(),
                         key=lambda item: item[1][2],
                         reverse=True)[:num_items]
         }
-        print(f"Top {num_items} modules in flops at depth {depth} are {sort_flops}")
+        print(f"Top {num_items} modules in MACs at depth {depth} are {sort_flops}")
         print(f"Top {num_items} modules in params at depth {depth} are {sort_params}")
-        print(f"Top {num_items} modules in time at depth {depth} are {sort_time}")
+        print(f"Top {num_items} modules in latency at depth {depth} are {sort_time}")
 def _prod(dims):
...
@@ -461,7 +462,8 @@ def wrapFunc(func, funcFlopCompute):
     def newFunc(*args, **kwds):
         flops = funcFlopCompute(*args, **kwds)
-        module_flop_count.append((name, flops))
+        if module_flop_count:
+            module_flop_count[-1].append((name, flops))
         return oldFunc(*args, **kwds)

     return newFunc
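The mechanism above, monkey-patched functionals appending their flop counts to the top of a stack that the module pre/post hooks push and pop, can be illustrated standalone. A minimal sketch assuming only `torch.nn.functional.linear` is wrapped (names here are illustrative, not DeepSpeed's):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

flop_stack = []  # one frame per module currently executing forward()
_orig_linear = F.linear

def counting_linear(input, weight, bias=None):
    # MACs of a linear op: (number of output rows) * out_features * in_features
    macs = (input.numel() // input.shape[-1]) * weight.shape[0] * weight.shape[1]
    if flop_stack:
        flop_stack[-1].append(("linear", macs))
    return _orig_linear(input, weight, bias)

F.linear = counting_linear  # patch the functional, as wrapFunc does

def pre_hook(module, inputs):
    flop_stack.append([])  # open a frame for this module's functionals

def post_hook(module, inputs, output):
    if flop_stack:
        module.__macs__ = sum(n for _, n in flop_stack.pop())

layer = nn.Linear(120, 84)
layer.register_forward_pre_hook(pre_hook)
layer.register_forward_hook(post_hook)

layer(torch.randn(1024, 120))
print(layer.__macs__)  # 10321920, i.e. 1024 * 84 * 120

F.linear = _orig_linear  # restore the original, as end_profile does
```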
...
@@ -630,25 +632,61 @@ MODULE_HOOK_MAPPING = {
 }


+def num_to_string(num, precision=2):
+    if num // 10**9 > 0:
+        return str(round(num / 10.0**9, precision)) + " G"
+    elif num // 10**6 > 0:
+        return str(round(num / 10.0**6, precision)) + " M"
+    elif num // 10**3 > 0:
+        return str(round(num / 10.0**3, precision)) + " K"
+    else:
+        return str(num)
+
+
+def macs_to_string(macs, units=None, precision=2):
+    if units is None:
+        if macs // 10**9 > 0:
+            return str(round(macs / 10.0**9, precision)) + " GMACs"
+        elif macs // 10**6 > 0:
+            return str(round(macs / 10.0**6, precision)) + " MMACs"
+        elif macs // 10**3 > 0:
+            return str(round(macs / 10.0**3, precision)) + " KMACs"
+        else:
+            return str(macs) + " MACs"
+    else:
+        if units == "GMACs":
+            return str(round(macs / 10.0**9, precision)) + " " + units
+        elif units == "MMACs":
+            return str(round(macs / 10.0**6, precision)) + " " + units
+        elif units == "KMACs":
+            return str(round(macs / 10.0**3, precision)) + " " + units
+        else:
+            return str(macs) + " MACs"
+
+
 def flops_to_string(flops, units=None, precision=2):
     if units is None:
+        if flops // 10**12 > 0:
+            return str(round(flops / 10.0**12, precision)) + " TFLOPS"
         if flops // 10**9 > 0:
-            return str(round(flops / 10.0**9, precision)) + " GMACs"
+            return str(round(flops / 10.0**9, precision)) + " GFLOPS"
         elif flops // 10**6 > 0:
-            return str(round(flops / 10.0**6, precision)) + " MMACs"
+            return str(round(flops / 10.0**6, precision)) + " MFLOPS"
         elif flops // 10**3 > 0:
-            return str(round(flops / 10.0**3, precision)) + " KMACs"
+            return str(round(flops / 10.0**3, precision)) + " KFLOPS"
         else:
-            return str(flops) + " MACs"
+            return str(flops) + " FLOPS"
     else:
-        if units == "GMACs":
+        if units == "TFLOPS":
+            return str(round(flops / 10.0**12, precision)) + " " + units
+        if units == "GFLOPS":
             return str(round(flops / 10.0**9, precision)) + " " + units
-        elif units == "MMACs":
+        elif units == "MFLOPS":
             return str(round(flops / 10.0**6, precision)) + " " + units
-        elif units == "KMACs":
+        elif units == "KFLOPS":
             return str(round(flops / 10.0**3, precision)) + " " + units
         else:
-            return str(flops) + " MACs"
+            return str(flops) + " FLOPS"
def params_to_string(params_num, units=None, precision=2): def params_to_string(params_num, units=None, precision=2):
...@@ -687,32 +725,40 @@ def duration_to_string(duration, units=None, precision=2): ...@@ -687,32 +725,40 @@ def duration_to_string(duration, units=None, precision=2):
return str(round(duration, precision)) + " s" return str(round(duration, precision)) + " s"
# cannot iterate over all submodules using self.model.modules()
# since modules() returns duplicate modules only once
def get_module_flops(module):
    total = module.__flops__
    # iterate over immediate children modules
    for child in module.children():
        total += get_module_flops(child)
    return total
def get_model_profile( def get_model_profile(
model, model,
input_res, input_res,
input_constructor=None, input_constructor=None,
print_profile=True, print_profile=True,
print_aggregated_profile=True, detailed=True,
module_depth=-1, module_depth=-1,
top_modules=3, top_modules=3,
warm_up=5, warm_up=1,
num_steps=10, as_string=True,
as_strings=True,
ignore_modules=None, ignore_modules=None,
): ):
"""Returns the total flops, parameters, and profiled steps of a model. """Returns the total MACs and parameters of a model.
Args: Args:
model ([torch.nn.Module]): the PyTorch model to be profiled. model ([torch.nn.Module]): the PyTorch model to be profiled.
input_res (list): input shape or input to the input_constructor input_res (list): input shape or input to the input_constructor
input_constructor (func, optional): input constructor. If specified, the constructor is applied to input_res and the constructor output is used as the input to the model. Defaults to None. input_constructor (func, optional): input constructor. If specified, the constructor is applied to input_res and the constructor output is used as the input to the model. Defaults to None.
print_profile (bool, optional): whether to print the model graph with the profile annotated. Defaults to True. print_profile (bool, optional): whether to print the model profile. Defaults to True.
print_aggregated_profile (bool, optional): whether to print the aggregated profile for top modules. Defaults to True. detailed (bool, optional): whether to print the detailed model profile. Defaults to True.
module_depth (int, optional): the depth into the nested modules. Defaults to -1 (the inner most modules). module_depth (int, optional): the depth into the nested modules. Defaults to -1 (the inner most modules).
top_modules (int, optional): the number of top modules to print in the aggregated profile. Defaults to 3. top_modules (int, optional): the number of top modules to print in the aggregated profile. Defaults to 3.
warm_up (int, optional): the number of warm-up steps before measuring the time of each module. Defaults to 5. warm_up (int, optional): the number of warm-up steps before measuring the latency of each module. Defaults to 1.
num_steps (int, optional): the number of steps to profile. Defaults to 10. as_string (bool, optional): whether to print the output as string. Defaults to True.
as_strings (bool, optional): whether to print the output as strings. Defaults to True.
ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None. ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.
""" """
assert type(input_res) is tuple assert type(input_res) is tuple
...@@ -738,7 +784,6 @@ def get_model_profile( ...@@ -738,7 +784,6 @@ def get_model_profile(
prof.start_profile(ignore_list=ignore_modules) prof.start_profile(ignore_list=ignore_modules)
for _ in range(num_steps):
if input_constructor: if input_constructor:
input = input_constructor(input_res) input = input_constructor(input_res)
_ = model(**input) _ = model(**input)
...@@ -756,14 +801,14 @@ def get_model_profile( ...@@ -756,14 +801,14 @@ def get_model_profile(
flops = prof.get_total_flops() flops = prof.get_total_flops()
params = prof.get_total_params() params = prof.get_total_params()
steps = prof.get_total_steps()
if print_profile: if print_profile:
prof.print_model_profile() prof.print_model_profile(profile_step=warm_up,
if print_aggregated_profile: module_depth=module_depth,
prof.print_model_aggregated_profile(module_depth=module_depth, top_modules=top_modules,
top_modules=top_modules) detailed=detailed)
prof.end_profile() prof.end_profile()
if as_strings: if as_string:
return flops_to_string(flops), params_to_string(params), steps return macs_to_string(flops), params_to_string(params)
return flops, params, steps return flops, params
...@@ -277,11 +277,8 @@ class DeepSpeedEngine(Module): ...@@ -277,11 +277,8 @@ class DeepSpeedEngine(Module):
def flops_profiler_enabled(self): def flops_profiler_enabled(self):
return self._config.flops_profiler_config.enabled return self._config.flops_profiler_config.enabled
def flops_profiler_start_step(self): def flops_profiler_profile_step(self):
return self._config.flops_profiler_config.start_step return self._config.flops_profiler_config.profile_step
def flops_profiler_end_step(self):
return self._config.flops_profiler_config.end_step
def flops_profiler_module_depth(self): def flops_profiler_module_depth(self):
return self._config.flops_profiler_config.module_depth return self._config.flops_profiler_config.module_depth
...@@ -289,6 +286,9 @@ class DeepSpeedEngine(Module): ...@@ -289,6 +286,9 @@ class DeepSpeedEngine(Module):
def flops_profiler_top_modules(self): def flops_profiler_top_modules(self):
return self._config.flops_profiler_config.top_modules return self._config.flops_profiler_config.top_modules
def flops_profiler_detailed(self):
return self._config.flops_profiler_config.detailed
def memory_breakdown(self): def memory_breakdown(self):
return self._config.memory_breakdown return self._config.memory_breakdown
...@@ -799,30 +799,11 @@ class DeepSpeedEngine(Module): ...@@ -799,30 +799,11 @@ class DeepSpeedEngine(Module):
**kwargs: variable length keyword arguments **kwargs: variable length keyword arguments
""" """
if self.flops_profiler_enabled( if self.flops_profiler_enabled(
) and self.global_steps == self.flops_profiler_start_step( ) and self.global_steps == self.flops_profiler_profile_step(
) and self.global_rank == 0: ) and self.global_rank == 0:
self.flops_profiler = FlopsProfiler(self.module) self.flops_profiler = FlopsProfiler(self.module)
self.flops_profiler.start_profile(ignore_list=None) self.flops_profiler.start_profile(ignore_list=None)
if self.flops_profiler_enabled(
) and self.global_steps == self.flops_profiler_end_step(
) and self.global_rank == 0:
print('{:<30} {:<8}'.format(
'Number of multiply-adds: ',
self.flops_profiler.get_total_flops(in_str=False)))
print('{:<30} {:<8}'.format(
'Number of parameters: ',
self.flops_profiler.get_total_params(in_str=False)))
print('{:<30} {:<8}'.format('Number of steps profiled: ',
self.flops_profiler.get_total_steps()))
self.flops_profiler.print_model_profile()
self.flops_profiler.print_model_aggregated_profile(
module_depth=self.flops_profiler_module_depth(),
top_modules=self.flops_profiler_top_modules())
self.flops_profiler.flops = self.flops_profiler.get_total_flops()
self.flops_profiler.params = self.flops_profiler.get_total_params()
self.flops_profiler.end_profile()
if self.module.training and self.progressive_layer_drop: if self.module.training and self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state()) kwargs.update(self.progressive_layer_drop.get_state())
...@@ -838,6 +819,16 @@ class DeepSpeedEngine(Module): ...@@ -838,6 +819,16 @@ class DeepSpeedEngine(Module):
self.timers('forward').stop() self.timers('forward').stop()
self.timers('forward_microstep').stop() self.timers('forward_microstep').stop()
if self.flops_profiler_enabled(
) and self.global_steps == self.flops_profiler_profile_step(
) and self.global_rank == 0:
self.flops_profiler.print_model_profile(
profile_step=self.global_steps,
module_depth=self.flops_profiler_module_depth(),
top_modules=self.flops_profiler_top_modules(),
detailed=self.flops_profiler_detailed())
self.flops_profiler.end_profile()
return loss return loss
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
......
...@@ -41,6 +41,7 @@ collections: ...@@ -41,6 +41,7 @@ collections:
- 1Cycle.md - 1Cycle.md
- lrrt.md - lrrt.md
- zero.md - zero.md
- flops-profiler.md
defaults: defaults:
- scope: - scope:
......
...@@ -45,6 +45,8 @@ lnav: ...@@ -45,6 +45,8 @@ lnav:
url: /docs/config-json/#zero-optimizations-for-fp16-training url: /docs/config-json/#zero-optimizations-for-fp16-training
- title: "Logging" - title: "Logging"
url: /docs/config-json/#logging url: /docs/config-json/#logging
- title: "Flops Profiler"
url: /docs/config-json/#flops-profiler
- title: "Activation checkpointing" - title: "Activation checkpointing"
url: /docs/config-json/#activation-checkpointing url: /docs/config-json/#activation-checkpointing
- title: "Sparse Attention" - title: "Sparse Attention"
...@@ -84,5 +86,7 @@ lnav: ...@@ -84,5 +86,7 @@ lnav:
url: /tutorials/pipeline/ url: /tutorials/pipeline/
- title: "Progressive Layer Dropping" - title: "Progressive Layer Dropping"
url: /tutorials/progressive_layer_dropping/ url: /tutorials/progressive_layer_dropping/
- title: "Flops Profiler"
url: /tutorials/flops-profiler/
- title: "Contributing" - title: "Contributing"
url: /contributing/ url: /contributing/
...@@ -10,20 +10,20 @@ title: "DeepSpeed Configuration JSON" ...@@ -10,20 +10,20 @@ title: "DeepSpeed Configuration JSON"
***train\_batch\_size***: [integer] ***train\_batch\_size***: [integer]
| Value | Example | | Value | Example |
| ------------------------------------------------------------ | ------- | | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The effective training batch size. This is the amount of data samples that leads to one step of model update. ***train\_batch\_size*** is aggregated by the batch size that a single GPU processes in one forward/backward pass (a.k.a., ***train\_step\_batch\_size***), the gradient accumulation steps (a.k.a., ***gradient\_accumulation\_steps***), and the number of GPUs. | `32` | | The effective training batch size. This is the amount of data samples that leads to one step of model update. ***train\_batch\_size*** is aggregated by the batch size that a single GPU processes in one forward/backward pass (a.k.a., ***train\_step\_batch\_size***), the gradient accumulation steps (a.k.a., ***gradient\_accumulation\_steps***), and the number of GPUs. | `32` |
***train\_micro\_batch\_size\_per\_gpu***: [integer] ***train\_micro\_batch\_size\_per\_gpu***: [integer]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ---------------------------- | | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------ |
| Batch size to be processed by one GPU in one step (without gradient accumulation). When specified, ***gradient\_accumulation\_steps*** is automatically calculated using ***train\_batch\_size*** and number of GPUs. Should not be concurrently specified with ***gradient\_accumulation\_steps*** in the configuration JSON. | ***train\_batch\_size*** value | | Batch size to be processed by one GPU in one step (without gradient accumulation). When specified, ***gradient\_accumulation\_steps*** is automatically calculated using ***train\_batch\_size*** and number of GPUs. Should not be concurrently specified with ***gradient\_accumulation\_steps*** in the configuration JSON. | ***train\_batch\_size*** value |
***gradient\_accumulation\_steps***: [integer] ***gradient\_accumulation\_steps***: [integer]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Number of training steps to accumulate gradients before averaging and applying them. This feature is sometimes useful to improve scalability since it results in less frequent communication of gradients between steps. Another impact of this feature is the ability to train with larger batch sizes per GPU. When specified, ***train\_step\_batch\_size*** is automatically calculated using ***train\_batch\_size*** and number of GPUs. Should not be concurrently specified with ***train\_step\_batch\_size*** in the configuration JSON. | `1` | | Number of training steps to accumulate gradients before averaging and applying them. This feature is sometimes useful to improve scalability since it results in less frequent communication of gradients between steps. Another impact of this feature is the ability to train with larger batch sizes per GPU. When specified, ***train\_step\_batch\_size*** is automatically calculated using ***train\_batch\_size*** and number of GPUs. Should not be concurrently specified with ***train\_step\_batch\_size*** in the configuration JSON. | `1` |
...@@ -33,7 +33,7 @@ title: "DeepSpeed Configuration JSON" ...@@ -33,7 +33,7 @@ title: "DeepSpeed Configuration JSON"
***optimizer***: [dictionary] ***optimizer***: [dictionary]
| Fields | Value | Example | | Fields | Value | Example |
| ------ | ------------------------------------------------------------ | ------------------------------ | | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------- |
| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, and **Lamb** optimizers and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | | type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, and **Lamb** optimizers and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` |
| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` | | params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` |
...@@ -56,7 +56,7 @@ title: "DeepSpeed Configuration JSON" ...@@ -56,7 +56,7 @@ title: "DeepSpeed Configuration JSON"
The Adam optimizer also supports the following two params keys/values in addition to the standard parameters from [torch.optim.Adam](https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam): The Adam optimizer also supports the following two params keys/values in addition to the standard parameters from [torch.optim.Adam](https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam):
| "params" key | Description | Default | | "params" key | Description | Default |
| ------------- | --------------------------------------------------------------------------- | --------| | ------------- | --------------------------------------------------------------------------- | ------- |
| torch\_adam | Use torch's implementation of adam instead of our fused adam implementation | false | | torch\_adam | Use torch's implementation of adam instead of our fused adam implementation | false |
| adam\_w\_mode | Apply L2 regularization (also known as AdamW) | true | | adam\_w\_mode | Apply L2 regularization (also known as AdamW) | true |
...@@ -84,7 +84,7 @@ The Adam optimizer also supports the following two params keys/values in additio ...@@ -84,7 +84,7 @@ The Adam optimizer also supports the following two params keys/values in additio
***scheduler***: [dictionary] ***scheduler***: [dictionary]
| Fields | Value | Example | | Fields | Value | Example |
| ------ | ------------------------------------------------------------ | ------------------------------ | | ------ | ---------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------- |
| type | The scheduler name. See [here](https://deepspeed.readthedocs.io/en/latest/deepspeed.pt.html) for list of support schedulers. | `"WarmupLR"` | | type | The scheduler name. See [here](https://deepspeed.readthedocs.io/en/latest/deepspeed.pt.html) for list of support schedulers. | `"WarmupLR"` |
| params | Dictionary of parameters to instantiate scheduler. The parameter names should match scheduler constructor signature. | `{"warmup_min_lr": 0, "warmup_max_lr": 0.001}` | | params | Dictionary of parameters to instantiate scheduler. The parameter names should match scheduler constructor signature. | `{"warmup_min_lr": 0, "warmup_max_lr": 0.001}` |
...@@ -106,7 +106,7 @@ Example of ***scheduler*** ...@@ -106,7 +106,7 @@ Example of ***scheduler***
***fp32\_allreduce***: [boolean] ***fp32\_allreduce***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------ | ------- | | -------------------------------------------------------------- | ------- |
| During gradient averaging perform allreduce with 32 bit values | `false` | | During gradient averaging perform allreduce with 32 bit values | `false` |
***prescale\_gradients***: [boolean] ***prescale\_gradients***: [boolean]
...@@ -118,13 +118,13 @@ Example of ***scheduler*** ...@@ -118,13 +118,13 @@ Example of ***scheduler***
***gradient_predivide_factor***: [float] ***gradient_predivide_factor***: [float]
| Description | Default | | Description | Default |
| ---------------------------- | ------- | | ------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Before gradient averaging predivide gradients by a specified factor, can sometimes help with fp16 stability when scaling to large numbers of GPUs | `1.0` | Before gradient averaging predivide gradients by a specified factor, can sometimes help with fp16 stability when scaling to large numbers of GPUs | `1.0` |
***sparse\_gradients***: [boolean] ***sparse\_gradients***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ------------------------------------------------------------------------------------------------------------------------ | ------- |
| Enable sparse compression of [torch.nn.Embedding](https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding) gradients. | `false` | | Enable sparse compression of [torch.nn.Embedding](https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding) gradients. | `false` |
### FP16 training options ### FP16 training options
...@@ -135,7 +135,7 @@ Example of ***scheduler*** ...@@ -135,7 +135,7 @@ Example of ***scheduler***
***fp16***: [dictionary] ***fp16***: [dictionary]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Configuration for using mixed precision/FP16 training that leverages [NVIDIA's Apex package](https://nvidia.github.io/apex/). An example, including the available dictionary keys is illustrated below. NOTE: this does not use Apex's AMP mode that allows for more flexibility in mixed precision training modes, this mode is similar to AMP's O2 mode. Please see AMP support below if you want to use more complex mixed precision modes. If you want to use ZeRO (currently) you must use this mode. | None | | Configuration for using mixed precision/FP16 training that leverages [NVIDIA's Apex package](https://nvidia.github.io/apex/). An example, including the available dictionary keys is illustrated below. NOTE: this does not use Apex's AMP mode that allows for more flexibility in mixed precision training modes, this mode is similar to AMP's O2 mode. Please see AMP support below if you want to use more complex mixed precision modes. If you want to use ZeRO (currently) you must use this mode. | None |
```json ```json
...@@ -152,37 +152,37 @@ Example of ***scheduler*** ...@@ -152,37 +152,37 @@ Example of ***scheduler***
***fp16:enabled***: [boolean] ***fp16:enabled***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | -------------------------------------------------------------------------------------- | ------- |
| ***enabled*** is a **fp16** parameter indicating whether or not FP16 training enabled. | `false` | | ***enabled*** is a **fp16** parameter indicating whether or not FP16 training enabled. | `false` |
***fp16:loss\_scale***: [float] ***fp16:loss\_scale***: [float]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| ***loss\_scale*** is a ***fp16*** parameter representing the loss scaling value for FP16 training. The default value of 0.0 results in dynamic loss scaling, otherwise the value will be used for static fixed loss scaling. | `0.0` | | ***loss\_scale*** is a ***fp16*** parameter representing the loss scaling value for FP16 training. The default value of 0.0 results in dynamic loss scaling, otherwise the value will be used for static fixed loss scaling. | `0.0` |
***fp16:initial\_scale\_power***: [integer] ***fp16:initial\_scale\_power***: [integer]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| ***initial\_loss\_scale\_power*** is a **fp16** parameter representing the power of the initial dynamic loss scale value. The actual loss scale is computed as 2<sup>***initial\_loss\_scale\_power***</sup>. | `32` | | ***initial\_loss\_scale\_power*** is a **fp16** parameter representing the power of the initial dynamic loss scale value. The actual loss scale is computed as 2<sup>***initial\_loss\_scale\_power***</sup>. | `32` |
***fp16:loss\_scale\_window***: [integer] ***fp16:loss\_scale\_window***: [integer]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | --------------------------------------------------------------------------------------------------------------------------------- | ------- |
| ***loss\_scale\_window*** is a **fp16** parameter representing the window over which to raise/lower the dynamic loss scale value. | `1000` | | ***loss\_scale\_window*** is a **fp16** parameter representing the window over which to raise/lower the dynamic loss scale value. | `1000` |
***fp16:hysteresis***: [integer] ***fp16:hysteresis***: [integer]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ---------------------------------------------------------------------------------------------- | ------- |
| ***hysteresis*** is a **fp16** parameter representing the delay shift in dynamic loss scaling. | `2` | | ***hysteresis*** is a **fp16** parameter representing the delay shift in dynamic loss scaling. | `2` |
***fp16:min\_loss\_scale***: [integer] ***fp16:min\_loss\_scale***: [integer]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | -------------------------------------------------------------------------------------------------- | ------- |
| ***min\_loss\_scale*** is a **fp16** parameter representing the minimum dynamic loss scale value. | `1000` | | ***min\_loss\_scale*** is a **fp16** parameter representing the minimum dynamic loss scale value. | `1000` |
### Automatic mixed precision (AMP) training options ### Automatic mixed precision (AMP) training options
...@@ -193,7 +193,7 @@ Example of ***scheduler*** ...@@ -193,7 +193,7 @@ Example of ***scheduler***
***amp***: [dictionary] ***amp***: [dictionary]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Configuration for using automatic mixed precision (AMP) training that leverages [NVIDIA's Apex AMP package](https://nvidia.github.io/apex/). An example, including the available dictionary keys is illustrated below. Is not compatible with `fp16` mode above or ZeRO. Any parameters outside of "enabled" will be passed to AMP's initialize call, see the API and descriptions here at the [apex.amp.initialize documentation](https://nvidia.github.io/apex/amp.html#apex.amp.initialize). | None | | Configuration for using automatic mixed precision (AMP) training that leverages [NVIDIA's Apex AMP package](https://nvidia.github.io/apex/). An example, including the available dictionary keys is illustrated below. Is not compatible with `fp16` mode above or ZeRO. Any parameters outside of "enabled" will be passed to AMP's initialize call, see the API and descriptions here at the [apex.amp.initialize documentation](https://nvidia.github.io/apex/amp.html#apex.amp.initialize). | None |
```json ```json
...@@ -208,13 +208,13 @@ Example of ***scheduler*** ...@@ -208,13 +208,13 @@ Example of ***scheduler***
***amp:enabled***: [boolean] ***amp:enabled***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ---------------------------------------------------------------------------------------- | ------- |
| ***enabled*** is an **amp** parameter indicating whether or not AMP training is enabled. | `false` | | ***enabled*** is an **amp** parameter indicating whether or not AMP training is enabled. | `false` |
***amp params***: [various] ***amp params***: [various]
| Description | Default | | Description | Default |
| ----------------------------------- | ------- | | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Any parameters outside of "enabled" will be passed to AMP's initialize call, see the API and descriptions here at the [apex.amp.initialize documentation](https://nvidia.github.io/apex/amp.html#apex.amp.initialize). | None | | Any parameters outside of "enabled" will be passed to AMP's initialize call, see the API and descriptions here at the [apex.amp.initialize documentation](https://nvidia.github.io/apex/amp.html#apex.amp.initialize). | None |
### Gradient Clipping ### Gradient Clipping
...@@ -246,55 +246,55 @@ Enabling and configuring ZeRO memory optimizations ...@@ -246,55 +246,55 @@ Enabling and configuring ZeRO memory optimizations
***zero\_optimization***: [dictionary] ***zero\_optimization***: [dictionary]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | --------------------------------------------------------------------------------------------------------- | ------- |
| Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. | `false` | | Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. | `false` |
***stage***: [integer] ***stage***: [integer]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Chooses different stages of ZeRO Optimizer. Stage 0, 1, and 2 refer to disabled, optimizer state partitioning, and optimizer+gradient state partitiong, respectively. | `0` | | Chooses different stages of ZeRO Optimizer. Stage 0, 1, and 2 refer to disabled, optimizer state partitioning, and optimizer+gradient state partitiong, respectively. | `0` |
***allgather_partitions***: [boolean] ***allgather_partitions***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ------------------------------------------------------------------------------------------------------------------------------------------------ | ------- |
| Chooses between allgather collective or a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step | `true` | | Chooses between allgather collective or a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step | `true` |
***allgather_bucket_size***: [boolean] ***allgather_bucket_size***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ------------------------------------------------------------------------------------------------------------ | ------- |
| Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes | `5e8` | | Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes | `5e8` |
***overlap_comm***: [boolean] ***overlap_comm***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ---------------------------------------------------------------------------- | ------- |
| Attempts to overlap the reduction of the gradients with backward computation | `false` | | Attempts to overlap the reduction of the gradients with backward computation | `false` |
***reduce_scatter***: [boolean] ***reduce_scatter***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ----------------------------------------------------------------------- | ------- |
| Uses reduce or reduce scatter instead of allreduce to average gradients | `true` | | Uses reduce or reduce scatter instead of allreduce to average gradients | `true` |
***reduce_bucket_size***: [boolean] ***reduce_bucket_size***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ------------------------------------------------------------------------------------------------------------------- | ------- |
| Number of elements reduced/allreduced at a time. Limits the memory required for the allgather for large model sizes | `5e8` | | Number of elements reduced/allreduced at a time. Limits the memory required for the allgather for large model sizes | `5e8` |
***contiguous_gradients***: [boolean] ***contiguous_gradients***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass. Only useful when running very large models. | `False` | | Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass. Only useful when running very large models. | `False` |
***cpu_offload***: [boolean] ***cpu_offload***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ------------------------------------------------------------------------------------------------------------------------ | ------- |
| Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. | `False` | | Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. | `False` |
...@@ -303,21 +303,63 @@ Enabling and configuring ZeRO memory optimizations ...@@ -303,21 +303,63 @@ Enabling and configuring ZeRO memory optimizations
***steps\_per\_print***: [integer] ***steps\_per\_print***: [integer]
| Description | Default | | Description | Default |
| ----------- | ------- | | ------------------------------ | ------- |
| Print train loss every N steps | `10` | | Print train loss every N steps | `10` |
***wall\_clock\_breakdown***: [boolean] ***wall\_clock\_breakdown***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ----------------------------------------------------------------------- | ------- |
| Enable timing of the latency of forward/backward/update training phases | `false` | | Enable timing of the latency of forward/backward/update training phases | `false` |
***dump_state***: [boolean] ***dump_state***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | -------------------------------------------------------------------- | ------- |
| Print out state information of DeepSpeed object after initialization | `false` | | Print out state information of DeepSpeed object after initialization | `false` |
### Flops Profiler
```json
{
  "flops_profiler": {
    "enabled": true,
    "profile_step": 1,
    "module_depth": -1,
    "top_modules": 3,
    "detailed": true
  }
}
```
***enabled***: [boolean]
| Description | Default |
| --------------------------- | ------- |
| Enables the flops profiler. | `false` |
***profile\_step***: [integer]
| Description | Default |
| --------------------------------------------------------------------------------------------------------------- | ------- |
| The global training step at which to profile. Note that warm-up steps are needed for accurate time measurement. | `1` |
***module\_depth***: [integer]
| Description | Default |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| The depth of the model at which to print the aggregated module information. When set to `-1`, it prints information on the innermost modules (with the maximum depth). | `-1` |
***top\_modules***: [integer]
| Description | Default |
| ---------------------------------------------------------------------------- | ------- |
| Limits the aggregated profile output to the number of top modules specified. | `3` |
***detailed***: [boolean]
| Description | Default |
| -------------------------------------------- | ------- |
| Whether to print the detailed model profile. | `true` |
### Activation Checkpointing ### Activation Checkpointing
```json ```json
"activation_checkpointing": { "activation_checkpointing": {
...@@ -332,39 +374,39 @@ Enabling and configuring ZeRO memory optimizations ...@@ -332,39 +374,39 @@ Enabling and configuring ZeRO memory optimizations
***partition\_activations***: [boolean] ***partition\_activations***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ------------------------------------------------------------- | ------- |
| Enables partition activation when used with model parallelism | `false` | | Enables partition activation when used with model parallelism | `false` |
***cpu\_checkpointing***: [boolean] ***cpu\_checkpointing***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | --------------------------------------------------------------------------- | ------- |
| Offloads partitioned activations to CPU if partition_activations is enabled| `false` | | Offloads partitioned activations to CPU if partition_activations is enabled | `false` |
***contiguous\_memory\_optimization***: [boolean] ***contiguous\_memory\_optimization***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | -------------------------------------------------------------------- | ------- |
| Copies partitioned activations so that they are contiguous in memory | `false` | | Copies partitioned activations so that they are contiguous in memory | `false` |
***number_checkpoints***: [integer] ***number_checkpoints***: [integer]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | -------------------------------------------------------------------------------------------------------- | ------- |
| Total number of activation checkpoints used to allocate memory buffer for contiguous_memoty_optimization | `None` | | Total number of activation checkpoints used to allocate memory buffer for contiguous_memoty_optimization | `None` |
***synchronize\_checkpoint\_boundary***: [boolean] ***synchronize\_checkpoint\_boundary***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ------------------------------------------------------------- | ------- |
| Inserts torch.cuda.synchronize() at each checkpoint boundary. | `false` | | Inserts torch.cuda.synchronize() at each checkpoint boundary. | `false` |
***profile***: [boolean] ***profile***: [boolean]
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | --------------------------------------------------------------- | ------- |
| Logs the forward and backward time for each checkpoint function | `false` | | Logs the forward and backward time for each checkpoint function | `false` |
### Sparse Attention ### Sparse Attention
...@@ -372,7 +414,7 @@ Enabling and configuring ZeRO memory optimizations ...@@ -372,7 +414,7 @@ Enabling and configuring ZeRO memory optimizations
***sparse\_attention***: [dictionary] ***sparse\_attention***: [dictionary]
| Fields | Value | Example | | Fields | Value | Example |
| ------ | ------------------------------------------------------------ | ------------------------------ | | -------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------- |
| mode | A string determining sparsity structure type. Deepspeed currently supports `"dense"`, `"fixed"`, `"bigbird"`, `"bslongformer"`, and `"variable"`. | `"fixed"` | | mode | A string determining sparsity structure type. Deepspeed currently supports `"dense"`, `"fixed"`, `"bigbird"`, `"bslongformer"`, and `"variable"`. | `"fixed"` |
| block | An integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. | 16 | | block | An integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. | 16 |
| different\_layout\_per\_head | A boolean determining if each head should be assigned a different sparsity layout; this will be satisfied based on availability. | false | | different\_layout\_per\_head | A boolean determining if each head should be assigned a different sparsity layout; this will be satisfied based on availability. | false |
......
...@@ -240,19 +240,53 @@ comes to data loading. Users simply provide a PyTorch dataset, and DeepSpeed dat ...@@ -240,19 +240,53 @@ comes to data loading. Users simply provide a PyTorch dataset, and DeepSpeed dat
can automatically handle batch creation appropriately. can automatically handle batch creation appropriately.
## Performance Analysis and Debugging ## Performance Analysis and Debugging
For performance debugging, DeepSpeed can give you a detailed breakdown of the time spent
in different parts of the training by simply enabling it in the `deepspeed_config` DeepSpeed provides a set of tools for performance analysis and debugging.
file.
Please see the [core API doc](https://deepspeed.readthedocs.io/) for more details. ### Wall Clock Breakdown
DeepSpeed provides a detailed breakdown of the time spent
in different parts of the training.
This can be enabled by setting the following in the `deepspeed_config` file.
```json ```json
{ {
"wall_clock_breakdown": true, "wall_clock_breakdown": true,
}
```
### Timing Activation Checkpoint Functions
When activation checkpointing is enabled, profiling of the forward and backward time of each checkpoint function can be enabled in the `deepspeed_config` file.
```json
{
"activation_checkpointing": { "activation_checkpointing": {
"profile": true "profile": true
} }
} }
```
### Flops Profiler
The DeepSpeed flops profiler measures the latency, flops, and parameters of a PyTorch model and shows which modules or layers are the bottleneck. When used with the DeepSpeed runtime, the flops profiler can be configured in the `deepspeed_config` file as follows:
```json
{
  "flops_profiler": {
    "enabled": true,
    "profile_step": 1,
    "module_depth": -1,
    "top_modules": 3,
    "detailed": true
  }
}
``` ```
The flops profiler can also be used as a standalone package. Please refer to the [Flops Profiler](/tutorials/flops-profiler) tutorial for more details.
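As a minimal standalone sketch (the torchvision model here is only a placeholder), the standalone usage boils down to a single call; see the tutorial for the full argument list:

```python
import torchvision.models as models
from deepspeed.profiling.flops_profiler import get_model_profile

model = models.resnet18()
macs, params = get_model_profile(model, input_res=(1, 3, 224, 224))
print(macs, params)
```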
## Sparse Attention ## Sparse Attention
DeepSpeed offers sparse attention to support long sequences. Please refer to the [Sparse Attention](/tutorials/sparse-attention/) tutorial. DeepSpeed offers sparse attention to support long sequences. Please refer to the [Sparse Attention](/tutorials/sparse-attention/) tutorial.
......
---
title: "Flops Profiler"
excerpt: "Measure the parameters, latency, and floating point operations of your model"
---
In this tutorial, we introduce the DeepSpeed flops profiler and provide examples of its usage.
- [Overview](#overview)
- [Supported Models](#supported-models)
- [Multi-GPU, Multi-node Runs](#multi-gpu-multi-node-runs)
- [Usage](#usage)
## Overview
The DeepSpeed flops profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module.
It shows the parameters, latency, and number of floating point operations of the modules within the model to identify potential bottlenecks.
It also outputs the names of the top `k` modules in terms of aggregated time, flops, and number of parameters at depth `l` with `k` and `l` specified by the user.
The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package.
The output profile is computed for each batch of input and printed to `stdout`. For each module, the measured profile is annotated after the module name, in the order: `number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency of the module, percentage of the total latency, floating point operations per second (FLOPS)`. Note that the profiler estimates the number of floating point operations as `2 * MACs` (each MAC operation is counted as two floating point operations).
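As a worked instance of this estimate, the LeNet5 profile below reports `2 * 439.56 MMACs / 25.7 ms ≈ 34.2 GFLOPS`.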
Below is an example output for LeNet5 with batch size 1024:
```shell
-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step: 1
Number of parameters: 61.71 k
Number of multiply-accumulate operations (MACs): 439.56 M
Number of floating point operations ( = 2 * MACs): 879.12 M
Latency: 25.7 ms
Floating point operations per second (FLOPS): 34.2 GFLOPS
----------------------------- Aggregated Profile -----------------------------
Top 3 modules in MACs at depth 2 are {'Conv2d': '421.91 MMACs', 'Linear': '11.18 MMACs', 'AvgPool2d': '6.46 MMACs'}
Top 3 modules in params at depth 2 are {'Conv2d': '50.69 k', 'Linear': '11.01 k', 'Tanh': '0'}
Top 3 modules in latency at depth 2 are {'Conv2d': '11.37 ms', 'Linear': '5.27 ms', 'AvgPool2d': '5.02 ms'}
------------------------------ Detailed Profile ------------------------------
Each module profile is listed after its name in the following order:
number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency).
Note:
1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, making the parent's MACs (or latency) differ from the sum of its submodules'.
2. The number of floating point operations is a theoretical estimation, so the FLOPS computed from it can be larger than the maximum system throughput.
LeNet5(
61.71 k, 100.00% Params, 439.56 MMACs, 100.00% MACs, 25.7 ms, 100.00% latency, 34.2 GFLOPS,
(feature_extractor): Sequential(
50.69 k, 82.15% Params, 428.37 MMACs, 97.45% MACs, 20.12 ms, 78.27% latency, 42.59 GFLOPS,
(0): Conv2d(156, 0.25% Params, 125.24 MMACs, 28.49% MACs, 9.8 ms, 38.12% latency, 25.56 GFLOPS, 1, 6, kernel_size=(5, 5), stride=(1, 1))
(1): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 2.85 ms, 11.08% latency, 0.0 FLOPS, )
(2): AvgPool2d(0, 0.00% Params, 4.82 MMACs, 1.10% MACs, 4.01 ms, 15.59% latency, 2.4 GFLOPS, kernel_size=2, stride=2, padding=0)
(3): Conv2d(2.42 k, 3.92% Params, 247.4 MMACs, 56.28% MACs, 924.83 us, 3.60% latency, 535.02 GFLOPS, 6, 16, kernel_size=(5, 5), stride=(1, 1))
(4): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 672.1 us, 2.62% latency, 0.0 FLOPS, )
(5): AvgPool2d(0, 0.00% Params, 1.64 MMACs, 0.37% MACs, 1.01 ms, 3.95% latency, 3.23 GFLOPS, kernel_size=2, stride=2, padding=0)
(6): Conv2d(48.12 k, 77.98% Params, 49.27 MMACs, 11.21% MACs, 647.31 us, 2.52% latency, 152.25 GFLOPS, 16, 120, kernel_size=(5, 5), stride=(1, 1))
(7): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 82.02 us, 0.32% latency, 0.0 FLOPS, )
)
(classifier): Sequential(
11.01 k, 17.85% Params, 11.18 MMACs, 2.54% MACs, 5.41 ms, 21.06% latency, 4.13 GFLOPS,
(0): Linear(10.16 k, 16.47% Params, 10.32 MMACs, 2.35% MACs, 2.47 ms, 9.60% latency, 8.37 GFLOPS, in_features=120, out_features=84, bias=True)
(1): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 90.12 us, 0.35% latency, 0.0 FLOPS, )
(2): Linear(850, 1.38% Params, 860.16 KMACs, 0.20% MACs, 2.8 ms, 10.91% latency, 613.62 MFLOPS, in_features=84, out_features=10, bias=True)
)
)
------------------------------------------------------------------------------
```
## Supported Models
The flops estimation is partly inspired by [ptflops](https://github.com/sovrasov/flops-counter.pytorch), with the major difference being that the DeepSpeed flops profiler captures `torch.nn.functional` calls invoked within a module to estimate the flops. Thus the DeepSpeed flops profiler allows for customized modules in the model, e.g., `ParallelTransformerLayer`, `ParallelSelfAttention`, `RowParallelLinear`, etc. in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). This is in contrast to tools that profile at the `torch.nn.Module` level, such as ptflops, which require users to write customized flops calculation functions for each customized module. Finally, the DeepSpeed flops profiler also supports flops computation at the module level (for RNNs).
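To make the difference concrete, below is a minimal sketch (the module and its names are invented, not taken from Megatron-LM): its matmuls go through `torch.nn.functional` rather than `nn.Linear` submodules, so a purely module-level counter sees nothing to hook, while the DeepSpeed flops profiler still counts the `F.linear` calls and attributes them to the module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyParallelMLP(nn.Module):
    """An invented custom module whose computation uses functionals
    instead of nn.Linear submodules."""
    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(ffn_hidden_size, hidden_size))
        self.w2 = nn.Parameter(torch.empty(hidden_size, ffn_hidden_size))
        nn.init.xavier_uniform_(self.w1)
        nn.init.xavier_uniform_(self.w2)

    def forward(self, x):
        # The two F.linear calls are intercepted by the profiler's
        # functional patching, so their MACs are charged to this module.
        return F.linear(F.relu(F.linear(x, self.w1)), self.w2)
```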
## Multi-GPU, Multi-node Runs
For models running on multiple GPUs or multiple nodes, only the model parallelism degree (e.g. `--model-parallel-size` in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)) affects the number of flops and parameters profiled, i.e.,
`model_parallel_size * flops = total_flops` and `model_parallel_size * parameters = total_parameters`. The number of GPUs or nodes does not affect the output profile.
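As an illustration with made-up numbers: under `--model-parallel-size 2`, a model whose full forward pass performs 600 GMACs would be profiled at roughly 300 GMACs (and half of the parameters) on each model-parallel rank, while adding data-parallel replicas would leave these per-rank numbers unchanged.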
## Usage
The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package. When using DeepSpeed for model training, the flops profiler can be configured in the `deepspeed_config` file without any user code changes. To use the flops profiler outside of the DeepSpeed runtime, simply install DeepSpeed and import the `flops_profiler` package to use the APIs directly. Examples of each usage are given below.
- [Usage With the DeepSpeed Runtime](#usage-with-the-deepspeed-runtime)
- [Example: Megatron-LM](#example-megatron-lm)
- [Usage Outside the DeepSpeed Runtime](#usage-outside-the-deepspeed-runtime)
- [In Model Inference](#in-model-inference)
- [Example: AlexNet](#example-alexnet)
- [Example: Bert](#example-bert)
- [In Model Training Workflow](#in-model-training-workflow)
- [Example Training Workflow](#example-training-workflow)
### Usage With the DeepSpeed Runtime
When using DeepSpeed for model training, the flops profiler can be configured in the `deepspeed_config` file. No explicit API calls are needed to use the profiler. Refer to [flops profiler](https://www.deepspeed.ai/docs/config-json/#flops-profiler) for details.
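As a minimal sketch (the model, data loader, and argument plumbing are placeholders; `args` is assumed to point at a DeepSpeed config containing a `flops_profiler` section like the one shown below), the profile is printed automatically at the configured `profile_step` with no profiler-specific code in the training loop:

```python
import deepspeed

# args.deepspeed_config is assumed to name a JSON file that includes
# a "flops_profiler" section.
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters())

for step, batch in enumerate(data_loader):
    loss = model_engine(batch)   # the profile is captured during this forward pass
    model_engine.backward(loss)
    model_engine.step()
```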
#### Example: Megatron-LM
For information on running Megatron-LM with DeepSpeed, please refer to our tutorial [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM).
The flops profiler can be enabled by adding the following field to the `deepspeed_config` file.
```json
{
  "flops_profiler": {
    "enabled": true,
    "profile_step": 1,
    "module_depth": -1,
    "top_modules": 3,
    "detailed": true
  }
}
```
An example output for a 4-layer Megatron-LM model (`hidden_size = 512, num_attention_heads = 16, batch_size = 8, seq_length = 1024`) is shown below.
```shell
-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step: 1
Number of parameters: 38.89 M
Number of multiply-accumulate operations (MACs): 314.61 G
Number of floating point operations ( = 2 * MACs): 629.21 G
Latency: 33.81 ms
Floating point operations per second (FLOPS): 18.61 TFLOPS
----------------------------- Aggregated Profile -----------------------------
Top 3 modules in MACs at depth 8 are {'ColumnParallelLinear': '60.13 GMACs', 'RowParallelLinear': '42.95 GMACs', 'FusedScaleMaskSoftmax': '536.87 MMACs'}
Top 3 modules in params at depth 8 are {'ColumnParallelLinear': '7.35 M', 'RowParallelLinear': '5.25 M', 'FusedScaleMaskSoftmax': '0'}
Top 3 modules in latency at depth 8 are {'ColumnParallelLinear': '659.23 us', 'RowParallelLinear': '587.94 us', 'FusedScaleMaskSoftmax': '370.98 us'}
------------------------------ Detailed Profile ------------------------------
Each module profile is listed after its name in the following order:
number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency).
Note:
1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, making the parent's MACs (or latency) differ from the sum of its submodules'.
2. The number of floating point operations is a theoretical estimation, so the FLOPS computed from it can be larger than the maximum system throughput.
DistributedDataParallel(
38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.81 ms, 100.00% latency, 18.61 TFLOPS,
(module): FP16_Module(
38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.77 ms, 99.89% latency, 18.63 TFLOPS,
(module): GPT2Model(
38.89 M, 100.00% Params, 314.61 GMACs, 100.00% MACs, 33.69 ms, 99.66% latency, 18.67 TFLOPS,
(language_model): TransformerLanguageModel(
38.89 M, 100.00% Params, 103.62 GMACs, 32.94% MACs, 5.58 ms, 16.51% latency, 37.13 TFLOPS,
(embedding): Embedding(
26.28 M, 67.57% Params, 0 MACs, 0.00% MACs, 545.98 us, 1.61% latency, 0.0 FLOPS,
(word_embeddings): VocabParallelEmbedding(25.76 M, 66.23% Params, 0 MACs, 0.00% MACs, 223.88 us, 0.66% latency, 0.0 FLOPS, )
(position_embeddings): Embedding(524.29 k, 1.35% Params, 0 MACs, 0.00% MACs, 147.1 us, 0.44% latency, 0.0 FLOPS, 1024, 512)
(embedding_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 79.39 us, 0.23% latency, 0.0 FLOPS, p=0.1, inplace=False)
)
(transformer): ParallelTransformer(
12.61 M, 32.43% Params, 103.62 GMACs, 32.94% MACs, 5.0 ms, 14.78% latency, 41.49 TFLOPS,
(layers): ModuleList(
12.61 M, 32.42% Params, 103.62 GMACs, 32.94% MACs, 4.4 ms, 13.01% latency, 47.13 TFLOPS,
(0): ParallelTransformerLayer(
3.15 M, 8.11% Params, 25.9 GMACs, 8.23% MACs, 1.36 ms, 4.02% latency, 38.09 TFLOPS,
(input_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 92.51 us, 0.27% latency, 0.0 FLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True)
(attention): ParallelSelfAttention(
1.05 M, 2.70% Params, 8.72 GMACs, 2.77% MACs, 754.59 us, 2.23% latency, 23.12 TFLOPS,
(query_key_value): ColumnParallelLinear(787.97 k, 2.03% Params, 6.44 GMACs, 2.05% MACs, 182.87 us, 0.54% latency, 70.46 TFLOPS, )
(scale_mask_softmax): FusedScaleMaskSoftmax(0, 0.00% Params, 134.22 MMACs, 0.04% MACs, 120.4 us, 0.36% latency, 2.23 TFLOPS, )
(attention_dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 47.45 us, 0.14% latency, 0.0 FLOPS, p=0.1, inplace=False)
(dense): RowParallelLinear(262.66 k, 0.68% Params, 2.15 GMACs, 0.68% MACs, 81.78 us, 0.24% latency, 52.52 TFLOPS, )
)
(post_attention_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 57.22 us, 0.17% latency, 0.0 FLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True)
(mlp): ParallelMLP(
2.1 M, 5.40% Params, 17.18 GMACs, 5.46% MACs, 224.83 us, 0.67% latency, 152.83 TFLOPS,
(dense_h_to_4h): ColumnParallelLinear(1.05 M, 2.70% Params, 8.59 GMACs, 2.73% MACs, 64.13 us, 0.19% latency, 267.87 TFLOPS, )
(dense_4h_to_h): RowParallelLinear(1.05 M, 2.70% Params, 8.59 GMACs, 2.73% MACs, 90.36 us, 0.27% latency, 190.13 TFLOPS, )
)
)
...
(3): ParallelTransformerLayer(...)
(final_layernorm): FusedLayerNorm(1.02 k, 0.00% Params, 0 MACs, 0.00% MACs, 52.69 us, 0.16% latency, 0.0 TFLOPS, torch.Size([512]), eps=1e-05, elementwise_affine=True)
)
)
)
)
)
```
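As a sanity check when reading the listing, each module's FLOPS value equals 2 * MACs / latency. For the first `ColumnParallelLinear` above: 2 × 6.44 GMACs / 182.87 µs ≈ 70.4 TFLOPS, matching the printed 70.46 TFLOPS up to rounding.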
### Usage Outside the DeepSpeed Runtime
The flops profiler can also be used as a standalone package outside of the DeepSpeed runtime:
simply install DeepSpeed and import the `flops_profiler` package to use the APIs directly.
Refer to [installation of DeepSpeed](https://www.deepspeed.ai/getting-started/#installation) if DeepSpeed is not already installed.
#### In Model Inference
To profile a trained model at inference time, use the `get_model_profile` function.
Examples are given below.
##### Example: AlexNet
The following example shows how to profile AlexNet using the DeepSpeed flops profiler.
```python
import torchvision.models as models
import torch
from deepspeed.profiling.flops_profiler import get_model_profile
with torch.cuda.device(0):
    model = models.alexnet()
    batch_size = 256
    macs, params = get_model_profile(model=model, # the PyTorch model to profile
                                     input_res=(batch_size, 3, 224, 224), # the input shape, or the input to input_constructor
                                     input_constructor=None, # if specified, a constructor taking input_res whose output is passed to the model
                                     print_profile=True, # print the model graph with the measured profile attached to each module
                                     detailed=True, # print the detailed profile
                                     module_depth=-1, # depth into the nested modules, with -1 being the innermost modules
                                     top_modules=3, # the number of top modules to print in the aggregated profile
                                     warm_up=10, # the number of warm-up steps before measuring the latency of each module
                                     as_string=True, # whether to print the output as human-readable strings (e.g. 1k) instead of raw numbers (e.g. 1000)
                                     ignore_modules=None) # the list of modules to ignore during profiling
```
An example output:
```shell
-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step: 10
Number of parameters: 61.1 M
Number of multiply-accumulate operations (MACs): 183.18 G
Number of floating point operations ( = 2 * MACs): 366.36 G
Latency: 22.13 ms
Floating point operations per second (FLOPS): 16.56 TFLOPS
----------------------------- Aggregated Profile -----------------------------
Top 3 modules in MACs at depth 2 are {'Conv2d': '167.95 GMACs', 'Linear': '15.01 GMACs', 'ReLU': '126.26 MMACs'}
Top 3 modules in params at depth 2 are {'Linear': '58.63 M', 'Conv2d': '2.47 M', 'ReLU': '0'}
Top 3 modules in latency at depth 2 are {'Conv2d': '13.96 ms', 'Linear': '6.23 ms', 'ReLU': '730.75 us'}
------------------------------ Detailed Profile ------------------------------
Each module profile is listed after its name in the following order:
number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency).
Note:
1. A module can invoke torch.nn.functional operations (e.g. to compute logits) in addition to its submodules, which is why a parent's MACs (or latency) can differ from the sum of its submodules'.
2. The number of floating point operations is a theoretical estimate, so the FLOPS computed from it can exceed the system's maximum achievable throughput.
AlexNet(
61.1 M, 100.00% Params, 183.18 GMACs, 100.00% MACs, 22.13 ms, 100.00% latency, 16.56 TFLOPS,
(features): Sequential(
2.47 M, 4.04% Params, 168.17 GMACs, 91.81% MACs, 15.17 ms, 68.57% latency, 22.17 TFLOPS,
(0): Conv2d(23.3 k, 0.04% Params, 18.04 GMACs, 9.85% MACs, 633.0 us, 2.86% latency, 57.0 TFLOPS, 3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): ReLU(0, 0.00% Params, 49.56 MMACs, 0.03% MACs, 163.79 us, 0.74% latency, 605.17 GFLOPS, inplace=True)
(2): MaxPool2d(0, 0.00% Params, 49.56 MMACs, 0.03% MACs, 159.26 us, 0.72% latency, 622.38 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(307.39 k, 0.50% Params, 57.37 GMACs, 31.32% MACs, 6.15 ms, 27.81% latency, 18.64 TFLOPS, 64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): ReLU(0, 0.00% Params, 35.83 MMACs, 0.02% MACs, 185.01 us, 0.84% latency, 387.34 GFLOPS, inplace=True)
(5): MaxPool2d(0, 0.00% Params, 35.83 MMACs, 0.02% MACs, 134.23 us, 0.61% latency, 533.89 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(663.94 k, 1.09% Params, 28.72 GMACs, 15.68% MACs, 389.58 us, 1.76% latency, 147.47 TFLOPS, 192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU(0, 0.00% Params, 16.61 MMACs, 0.01% MACs, 76.53 us, 0.35% latency, 434.15 GFLOPS, inplace=True)
(8): Conv2d(884.99 k, 1.45% Params, 38.29 GMACs, 20.90% MACs, 6.38 ms, 28.82% latency, 12.01 TFLOPS, 384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): ReLU(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 104.43 us, 0.47% latency, 212.12 GFLOPS, inplace=True)
(10): Conv2d(590.08 k, 0.97% Params, 25.53 GMACs, 13.94% MACs, 405.79 us, 1.83% latency, 125.83 TFLOPS, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 65.57 us, 0.30% latency, 337.85 GFLOPS, inplace=True)
(12): MaxPool2d(0, 0.00% Params, 11.08 MMACs, 0.01% MACs, 122.07 us, 0.55% latency, 181.46 GFLOPS, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(0, 0.00% Params, 2.36 MMACs, 0.00% MACs, 259.4 us, 1.17% latency, 18.19 GFLOPS, output_size=(6, 6))
(classifier): Sequential(
58.63 M, 95.96% Params, 15.01 GMACs, 8.19% MACs, 6.54 ms, 29.54% latency, 4.59 TFLOPS,
(0): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 42.68 us, 0.19% latency, 0.0 FLOPS, p=0.5, inplace=False)
(1): Linear(37.75 M, 61.79% Params, 9.66 GMACs, 5.28% MACs, 301.36 us, 1.36% latency, 64.13 TFLOPS, in_features=9216, out_features=4096, bias=True)
(2): ReLU(0, 0.00% Params, 1.05 MMACs, 0.00% MACs, 79.39 us, 0.36% latency, 26.41 GFLOPS, inplace=True)
(3): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 39.58 us, 0.18% latency, 0.0 FLOPS, p=0.5, inplace=False)
(4): Linear(16.78 M, 27.46% Params, 4.29 GMACs, 2.34% MACs, 234.37 us, 1.06% latency, 36.65 TFLOPS, in_features=4096, out_features=4096, bias=True)
(5): ReLU(0, 0.00% Params, 1.05 MMACs, 0.00% MACs, 56.03 us, 0.25% latency, 37.43 GFLOPS, inplace=True)
(6): Linear(4.1 M, 6.71% Params, 1.05 GMACs, 0.57% MACs, 5.69 ms, 25.72% latency, 368.42 GFLOPS, in_features=4096, out_features=1000, bias=True)
)
)
------------------------------------------------------------------------------
```
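Since the profiler measures a single forward pass over the whole batch, the per-image cost follows by dividing by the batch size: 183.18 GMACs / 256 ≈ 715.5 MMACs per image, which is in line with commonly cited numbers for AlexNet at 224×224 input.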
##### Example: BERT
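Unlike AlexNet, whose input is a plain tensor, `BertForSequenceClassification` expects a dict of keyword arguments produced by the tokenizer. The example below therefore passes an `input_constructor` that builds the tokenized inputs (plus labels) for a given `(batch_size, seq_len)` shape.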
```python
from functools import partial
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.profiling.flops_profiler import get_model_profile
def bert_input_constructor(input_shape, tokenizer):
    # build a batch of padded sequences; subtract 2 from the sequence length
    # to account for the special tokens [CLS] and [SEP] added by the tokenizer
    fake_seq = ""
    for _ in range(input_shape[1] - 2):
        fake_seq += tokenizer.pad_token
    inputs = tokenizer([fake_seq] * input_shape[0],
                       padding=True,
                       truncation=True,
                       return_tensors="pt")
    labels = torch.tensor([1] * input_shape[0])
    inputs = dict(inputs)
    inputs.update({"labels": labels})
    return inputs

with torch.cuda.device(0):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    batch_size = 4
    seq_len = 128
    enable_profile = True
    if enable_profile:
        macs, params = get_model_profile(
            model,
            (batch_size, seq_len),
            input_constructor=partial(bert_input_constructor,
                                      tokenizer=tokenizer),
            print_profile=True,
            detailed=True,
        )
    else:
        inputs = bert_input_constructor((batch_size, seq_len), tokenizer)
        outputs = model(**inputs)  # unpack the dict returned by the constructor
```
An example output:
```shell
-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step: 1
Number of parameters: 109.48 M
Number of multiply-accumulate operations (MACs): 43.5 G
Number of floating point operations ( = 2 * MACs): 87.0 G
Latency: 393.7 ms
Floating point operations per second (FLOPS): 220.97 GFLOPS
----------------------------- Aggregated Profile -----------------------------
Top 3 modules in MACs at depth 7 are {'Linear': '14.5 GMACs', 'Dropout': '0 MACs', 'LayerNorm': '0 MACs'}
Top 3 modules in params at depth 7 are {'Linear': '28.35 M', 'LayerNorm': '18.43 k', 'Dropout': '0'}
Top 3 modules in latency at depth 7 are {'Linear': '153.7 ms', 'LayerNorm': '4.74 ms', 'Dropout': '597.95 us'}
------------------------------ Detailed Profile ------------------------------
Each module profile is listed after its name in the following order:
number of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency).
Note:
1. A module can invoke torch.nn.functional operations (e.g. to compute logits) in addition to its submodules, which is why a parent's MACs (or latency) can differ from the sum of its submodules'.
2. The number of floating point operations is a theoretical estimate, so the FLOPS computed from it can exceed the system's maximum achievable throughput.
BertForSequenceClassification(
109.48 M, 100.00% Params, 43.5 GMACs, 100.00% MACs, 393.7 ms, 100.00% latency, 220.97 GFLOPS,
(bert): BertModel(
109.48 M, 100.00% Params, 43.5 GMACs, 100.00% MACs, 393.38 ms, 99.92% latency, 221.15 GFLOPS,
(embeddings): BertEmbeddings(
23.84 M, 21.77% Params, 0 MACs, 0.00% MACs, 1.79 ms, 0.45% latency, 0.0 FLOPS,
(word_embeddings): Embedding(23.44 M, 21.41% Params, 0 MACs, 0.00% MACs, 485.18 us, 0.12% latency, 0.0 FLOPS, 30522, 768, padding_idx=0)
(position_embeddings): Embedding(393.22 k, 0.36% Params, 0 MACs, 0.00% MACs, 111.1 us, 0.03% latency, 0.0 FLOPS, 512, 768)
(token_type_embeddings): Embedding(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 215.53 us, 0.05% latency, 0.0 FLOPS, 2, 768)
(LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 386.95 us, 0.10% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 20.27 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False)
)
(encoder): BertEncoder(
85.05 M, 77.69% Params, 43.5 GMACs, 99.99% MACs, 391.03 ms, 99.32% latency, 222.47 GFLOPS,
(layer): ModuleList(
85.05 M, 77.69% Params, 43.5 GMACs, 99.99% MACs, 390.82 ms, 99.27% latency, 222.59 GFLOPS,
(0): BertLayer(
7.09 M, 6.47% Params, 3.62 GMACs, 8.33% MACs, 31.91 ms, 8.10% latency, 227.21 GFLOPS,
(attention): BertAttention(
2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 16.39 ms, 4.16% latency, 147.47 GFLOPS,
(self): BertSelfAttention(
1.77 M, 1.62% Params, 906.76 MMACs, 2.08% MACs, 15.07 ms, 3.83% latency, 120.36 GFLOPS,
(query): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 3.66 ms, 0.93% latency, 164.91 GFLOPS, in_features=768, out_features=768, bias=True)
(key): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 3.72 ms, 0.94% latency, 162.36 GFLOPS, in_features=768, out_features=768, bias=True)
(value): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 4.52 ms, 1.15% latency, 133.65 GFLOPS, in_features=768, out_features=768, bias=True)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 24.08 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False)
)
(output): BertSelfOutput(
592.13 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 1.29 ms, 0.33% latency, 469.21 GFLOPS,
(dense): Linear(590.59 k, 0.54% Params, 301.99 MMACs, 0.69% MACs, 504.26 us, 0.13% latency, 1.2 TFLOPS, in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 437.97 us, 0.11% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 21.93 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 9.57 ms, 2.43% latency, 252.35 GFLOPS,
(dense): Linear(2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 8.75 ms, 2.22% latency, 276.11 GFLOPS, in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 5.77 ms, 1.47% latency, 418.39 GFLOPS,
(dense): Linear(2.36 M, 2.16% Params, 1.21 GMACs, 2.78% MACs, 5.13 ms, 1.30% latency, 471.15 GFLOPS, in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm(1.54 k, 0.00% Params, 0 MACs, 0.00% MACs, 310.9 us, 0.08% latency, 0.0 FLOPS, (768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 29.8 us, 0.01% latency, 0.0 FLOPS, p=0.1, inplace=False)
)
)
...
(11): BertLayer(...)
)
)
(pooler): BertPooler(
590.59 k, 0.54% Params, 2.36 MMACs, 0.01% MACs, 337.12 us, 0.09% latency, 14.0 GFLOPS,
(dense): Linear(590.59 k, 0.54% Params, 2.36 MMACs, 0.01% MACs, 173.57 us, 0.04% latency, 27.19 GFLOPS, in_features=768, out_features=768, bias=True)
(activation): Tanh(0, 0.00% Params, 0 MACs, 0.00% MACs, 46.01 us, 0.01% latency, 0.0 FLOPS, )
)
)
(dropout): Dropout(0, 0.00% Params, 0 MACs, 0.00% MACs, 19.55 us, 0.00% latency, 0.0 FLOPS, p=0.1, inplace=False)
(classifier): Linear(1.54 k, 0.00% Params, 6.14 KMACs, 0.00% MACs, 56.51 us, 0.01% latency, 217.47 MFLOPS, in_features=768, out_features=2, bias=True)
)
------------------------------------------------------------------------------
```
#### In Model Training Workflow
To profile a model's forward pass in a training workflow, use the `FlopsProfiler` class.
The `FlopsProfiler` class provides the following methods:
* `start_profile()` - starts profiling
* `get_total_flops(as_string=False)` - returns the total number of MACs in the model (note: despite the name, the count is in MACs; flops = 2 * MACs)
* `get_total_params(as_string=False)` - returns the total number of parameters in the model
* `print_model_profile(profile_step=1, module_depth=-1, top_modules=3, detailed=True)` - prints the model profile
* `end_profile()` - ends profiling and cleans up; invoke it at the end of profiling and AFTER calling `get_total_flops`, `get_total_params`, or `print_model_profile`.
##### Example Training Workflow
Below is an example of this usage in a typical training workflow. Note that the flops profiler captures only the forward pass of a training step; the flops of the backward pass can be roughly estimated as twice that of the forward pass.
```python
from deepspeed.profiling.flops_profiler import FlopsProfiler
model = Model()
prof = FlopsProfiler(model)

profile_step = 5
print_profile = True

for step, batch in enumerate(data_loader):

    # start profiling at training step "profile_step"
    if step == profile_step:
        prof.start_profile()

    # forward pass
    loss = model(batch)

    # end profiling and print the output
    if step == profile_step: # if using multiple nodes, also check global_rank == 0
        flops = prof.get_total_flops(as_string=True)
        params = prof.get_total_params(as_string=True)
        if print_profile:
            prof.print_model_profile(profile_step=profile_step)
        prof.end_profile()

    # backpropagation
    loss.backward()

    # weight update
    optimizer.step()
```
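Combining the profiled numbers with the ~2x rule of thumb above gives a quick estimate of training throughput. A rough sketch (here `step_time` is a hypothetical measurement of seconds per training step, which the profiler does not provide):
```python
# a rough sketch, assuming backward ≈ 2x forward, i.e. ~3x forward flops per
# training step; get_total_flops() returns MACs (see above), so flops = 2 * MACs
macs = prof.get_total_flops()                       # call before prof.end_profile()
train_flops_per_step = 3 * 2 * macs                 # 1x forward + ~2x backward
tflops = train_flops_per_step / (step_time * 1e12)  # step_time: measured seconds per step
```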
...@@ -24,8 +24,7 @@ def test_flops_profiler_in_ds_trainning(tmpdir): ...@@ -24,8 +24,7 @@ def test_flops_profiler_in_ds_trainning(tmpdir):
}, },
"flops_profiler": { "flops_profiler": {
"enabled": True, "enabled": True,
"start_step": 2, "step": 1,
"end_step": 3,
"module_depth": -1, "module_depth": -1,
"top_modules": 3, "top_modules": 3,
}, },
...@@ -100,18 +99,17 @@ def test_flops_profiler_in_inference(): ...@@ -100,18 +99,17 @@ def test_flops_profiler_in_inference():
mod = LeNet5(10) mod = LeNet5(10)
batch_size = 1024 batch_size = 1024
input = torch.randn(batch_size, 1, 32, 32) input = torch.randn(batch_size, 1, 32, 32)
macs, params, steps = get_model_profile( macs, params = get_model_profile(
mod, mod,
tuple(input.shape), tuple(input.shape),
print_profile=True, print_profile=True,
print_aggregated_profile=True, detailed=True,
module_depth=-1, module_depth=-1,
top_modules=3, top_modules=3,
warm_up=5, warm_up=1,
num_steps=10, as_string=True,
as_strings=True,
ignore_modules=None, ignore_modules=None,
) )
print(macs, params, steps) print(macs, params)
assert macs == "439.55 MMACs" assert macs == "439.56 MMACs"
assert params == "61.71 k" assert params == "61.71 k"