Unverified Commit e0f36ed5 authored by Cheng Li, committed by GitHub

Add optimizers and schedules to RTD and update the corresponding part of the website (#799)



* add optimizers and schedules to rtd

* update ds website and fix links

* add optimizers and schedules to rtd

* update ds website and fix links

* add flops profiler to rtd

* fix
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
parent 29853c3e
...@@ -9,9 +9,6 @@ from deepspeed.profiling.constants import *
class DeepSpeedFlopsProfilerConfig(object):
    def __init__(self, param_dict):
        """
        docstring
        """
        super(DeepSpeedFlopsProfilerConfig, self).__init__()
        self.enabled = None
...@@ -27,9 +24,6 @@ class DeepSpeedFlopsProfilerConfig(object):
        self._initialize(flops_profiler_dict)

    def _initialize(self, flops_profiler_dict):
        """
        docstring
        """
        self.enabled = get_scalar_param(flops_profiler_dict,
                                        FLOPS_PROFILER_ENABLED,
                                        FLOPS_PROFILER_ENABLED_DEFAULT)
......
...@@ -12,6 +12,34 @@ class FlopsProfiler(object):
    """Measures the latency, number of estimated floating point operations and parameters of each module in a PyTorch model.

    The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the measured profile attached to each module. It shows how latency, flops and parameters are spent in the model and which modules or layers could be the bottleneck. It also outputs the names of the top k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified by the user. The output profile is computed for each batch of input.
    The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package.

    When using DeepSpeed for model training, the flops profiler can be configured in the deepspeed_config file and no user code change is required.

    If using the profiler as a standalone package, one imports the flops_profiler package and uses the APIs.

    Here is an example of usage in a typical training workflow:

    .. code-block:: python

        model = Model()
        prof = FlopsProfiler(model)

        for step, batch in enumerate(data_loader):
            # start profiling at the chosen training step
            if step == profile_step:
                prof.start_profile()

            loss = model(batch)

            # collect and print the profile after the forward pass, then stop profiling
            if step == profile_step:
                flops = prof.get_total_flops(as_string=True)
                params = prof.get_total_params(as_string=True)
                prof.print_model_profile(profile_step=profile_step)
                prof.end_profile()

            loss.backward()
            optimizer.step()

    To profile a trained model in inference, use the `get_model_profile` API.
    Args:
        object (torch.nn.Module): The PyTorch model to profile.
...@@ -118,6 +146,9 @@ class FlopsProfiler(object):
        Args:
            as_string (bool, optional): whether to output the flops as string. Defaults to False.

        Returns:
            The number of multiply-accumulate operations of the model forward pass.
        """
        total_flops = get_module_flops(self.model)
        return macs_to_string(total_flops) if as_string else total_flops
...@@ -127,6 +158,9 @@ class FlopsProfiler(object):
        Args:
            as_string (bool, optional): whether to output the duration as string. Defaults to False.

        Returns:
            The latency of the model forward pass.
        """
        total_duration = self.model.__duration__
        return duration_to_string(total_duration) if as_string else total_duration
...@@ -136,6 +170,9 @@ class FlopsProfiler(object):
        Args:
            as_string (bool, optional): whether to output the parameters as string. Defaults to False.

        Returns:
            The number of parameters in the model.
        """
        return params_to_string(
            self.model.__params__) if as_string else self.model.__params__
...@@ -146,6 +183,12 @@ class FlopsProfiler(object):
                            top_modules=3,
                            detailed=True):
        """Prints the model graph with the measured profile attached to each module.

        Args:
            profile_step (int, optional): The global training step at which to profile. Note that warm up steps are needed for accurate time measurement.
            module_depth (int, optional): The depth of the model at which to print the aggregated module information. When set to -1, it prints information on the innermost modules (with the maximum depth).
            top_modules (int, optional): Limits the aggregated profile output to the number of top modules specified.
            detailed (bool, optional): Whether to print the detailed model profile.
        """
        total_flops = self.get_total_flops()
...@@ -219,7 +262,7 @@ class FlopsProfiler(object):
            "\n------------------------------ Detailed Profile ------------------------------"
        )
        print(
            "Each module profile is listed after its name in the following order: \nnumber of parameters, percentage of total parameters, number of multiply-accumulate operations (MACs), percentage of total MACs, latency, percentage of total latency, number of floating point operations per second (FLOPS, computed as 2 * MACs / latency)."
        )
        print(
"Note: \n1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'.\n2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught.\n" "Note: \n1. A module can have torch.nn.functional (e.g. to compute logits) along with submodules, thus making the difference between the parent's MACs(or latency) and the sum of its submodules'.\n2. Number of floating point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throught.\n"
...@@ -749,6 +792,14 @@ def get_model_profile(
):
    """Returns the total MACs and parameters of a model.

    Example:

    .. code-block:: python

        model = torchvision.models.alexnet()
        batch_size = 256
        macs, params = get_model_profile(model=model, input_res=(batch_size, 3, 224, 224))
    Args:
        model ([torch.nn.Module]): the PyTorch model to be profiled.
        input_res (list): input shape or input to the input_constructor
...@@ -760,6 +811,9 @@ def get_model_profile(
        warm_up (int, optional): the number of warm-up steps before measuring the latency of each module. Defaults to 1.
        as_string (bool, optional): whether to print the output as string. Defaults to True.
        ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.

    Returns:
        The number of multiply-accumulate operations (MACs) and parameters in the model.
    """
    assert type(input_res) is tuple
    assert len(input_res) >= 1
......
...@@ -33,8 +33,8 @@ title: "DeepSpeed Configuration JSON"
***optimizer***: [dictionary]

| Fields | Value | Example |
| ------ | ----- | ------- |
| type   | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, and **Lamb** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` |
| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` |

Example of ***optimizer*** with Adam

...@@ -84,8 +84,8 @@ The Adam optimizer also supports the following two params keys/values in additio
***scheduler***: [dictionary]

| Fields | Value | Example |
| ------ | ----- | ------- |
| type   | The scheduler name. See [here](https://deepspeed.readthedocs.io/en/latest/schedulers.html) for a list of supported schedulers. | `"WarmupLR"` |
| params | Dictionary of parameters to instantiate scheduler. The parameter names should match the scheduler constructor signature. | `{"warmup_min_lr": 0, "warmup_max_lr": 0.001}` |

Example of ***scheduler***
...@@ -164,7 +164,7 @@ Example of ***scheduler***
***fp16:initial\_scale\_power***: [integer]

| Description | Default |
| ----------- | ------- |
| ***initial\_scale\_power*** is a **fp16** parameter representing the power of the initial dynamic loss scale value. The actual loss scale is computed as 2<sup>***initial\_scale\_power***</sup>. | `32` |
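The relationship between ***initial\_scale\_power*** and the starting loss scale can be checked with a tiny snippet (illustration only, not part of the original page):

```python
# Illustration only: the initial dynamic loss scale implied by initial_scale_power.
initial_scale_power = 32              # default value from the table above
initial_loss_scale = 2 ** initial_scale_power
print(initial_loss_scale)             # 4294967296
```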
***fp16:loss\_scale\_window***: [integer]
......
...@@ -88,6 +88,8 @@ The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a stan
- [Example: Bert](#example-bert)
- [In Model Training Workflow](#in-model-training-workflow)
- [Example Training Workflow](#example-training-workflow)

### Usage With the DeepSpeed Runtime
When using DeepSpeed for model training, the flops profiler can be configured in the `deepspeed_config` file. No explicit API calls are needed to use the profiler. Refer to [flops profiler](https://www.deepspeed.ai/docs/config-json/#flops-profiler) for details.
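As a rough sketch (not taken from the original page), the profiler would be switched on through a `flops_profiler` block in the DeepSpeed config. Only the `enabled` key is confirmed by the config code in this commit; the remaining keys mirror the `print_model_profile` options and are assumptions:

```python
# Hypothetical deepspeed_config fragment, written as a Python dict for illustration.
ds_config = {
    "train_batch_size": 256,
    "flops_profiler": {
        "enabled": True,       # FLOPS_PROFILER_ENABLED in the config constants
        "profile_step": 1,     # assumed: mirrors print_model_profile(profile_step=...)
        "module_depth": -1,    # assumed: mirrors module_depth
        "top_modules": 3,      # assumed: mirrors top_modules
        "detailed": True,      # assumed: mirrors detailed
    },
}
```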
......
DeepSpeedCPUAdam
################
.. autoclass:: deepspeed.ops.adam.DeepSpeedCPUAdam
    :members:
Flops Profiler
==============
The flops profiler in DeepSpeed profiles the forward pass of a model and measures its parameters, latency, and floating point operations. The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package.
When using DeepSpeed for model training, the flops profiler can be configured in the deepspeed_config file without user code changes. To use the flops profiler outside of the DeepSpeed runtime, one can simply install DeepSpeed and import the flops_profiler package to use the APIs directly.
Please see the `Flops Profiler tutorial <https://www.deepspeed.ai/tutorials/flops-profiler/>`_ for usage details.
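For orientation, a minimal standalone sketch follows; it only uses the arguments documented for ``get_model_profile`` in this commit, and the import path (taken from the ``automodule`` directive below) should be treated as an assumption.

.. code-block:: python

    import torchvision.models as models
    # Import path assumed from the automodule directive on this page.
    from deepspeed.profiling.flops_profiler.profiler import get_model_profile

    model = models.alexnet()
    # Profile one forward pass over a batch of 256 images (224x224 RGB).
    macs, params = get_model_profile(model=model,
                                     input_res=(256, 3, 224, 224),
                                     warm_up=1,
                                     as_string=True)
    print(macs, params)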
Flops Profiler
---------------------------------------------------
.. automodule:: deepspeed.profiling.flops_profiler.profiler
    :members:
    :show-inheritance:
...@@ -33,7 +33,6 @@ ZeRO API
   :maxdepth: 2

   zero3
   cpu-adam
...@@ -51,6 +50,26 @@ Pipeline Parallelism
   pipeline
Optimizers
--------------------
.. toctree::
   :maxdepth: 2

   optimizers


Learning Rate Schedulers
------------------------
.. toctree::
   :maxdepth: 2

   schedulers


Flops Profiler
--------------------
.. toctree::
   :maxdepth: 2

   flops-profiler
Indices and tables
------------------
......
Optimizers
===================
DeepSpeed offers high-performance implementations of the ``Adam`` optimizer on CPU, and of the ``FusedAdam``, ``FusedLamb``, and ``OneBitAdam`` optimizers on GPU.
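For a quick sense of how these classes are used directly (outside the usual route of selecting an optimizer through the DeepSpeed config), here is a minimal sketch; it assumes the standard PyTorch optimizer constructor shape of a parameter iterable followed by keyword options such as ``lr``.

.. code-block:: python

    import torch
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    model = torch.nn.Linear(10, 10)
    # Assumes the usual PyTorch-style constructor: parameters first, then keyword options.
    optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)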
Adam (CPU)
----------------------------
.. autoclass:: deepspeed.ops.adam.DeepSpeedCPUAdam

FusedAdam (GPU)
----------------------------
.. autoclass:: deepspeed.ops.adam.FusedAdam
FusedLamb (GPU)
----------------------------
.. autoclass:: deepspeed.ops.lamb.FusedLamb
OneBitAdam (GPU)
----------------------------
.. autoclass:: deepspeed.runtime.fp16.OneBitAdam
Learning Rate Schedulers
========================
DeepSpeed offers implementations of ``LRRangeTest``, ``OneCycle``, ``WarmupLR``, ``WarmupDecayLR`` learning rate schedulers.
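A minimal sketch of driving one of these schedulers directly follows; the keyword names mirror the ``scheduler`` ``params`` shown in the DeepSpeed configuration docs (``warmup_min_lr`` / ``warmup_max_lr``), and other constructor options are left at their defaults, so treat the exact signature as an assumption.

.. code-block:: python

    import torch
    from deepspeed.runtime.lr_schedules import WarmupLR

    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Keyword names follow the "scheduler" params documented in the config JSON docs.
    scheduler = WarmupLR(optimizer, warmup_min_lr=0.0, warmup_max_lr=0.001)

    for step in range(100):
        optimizer.step()
        scheduler.step()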
LRRangeTest
---------------------------
.. autoclass:: deepspeed.runtime.lr_schedules.LRRangeTest
OneCycle
---------------------------
.. autoclass:: deepspeed.runtime.lr_schedules.OneCycle
WarmupLR
---------------------------
.. autoclass:: deepspeed.runtime.lr_schedules.WarmupLR
WarmupDecayLR
---------------------------
.. autoclass:: deepspeed.runtime.lr_schedules.WarmupDecayLR