..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

|License|

Transformer Engine
==================

`Quickstart <#examples>`_ | `Installation <#installation>`_ | `User Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html>`_ | `Examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_ | `FP8 Convergence <#fp8-convergence>`_ | `Integrations <#integrations>`_ | `Release notes <https://docs.nvidia.com/deeplearning/transformer-engine/release-notes/index.html>`_

Latest News
===========

* [03/2024] `Turbocharged Training: Optimizing the Databricks Mosaic AI stack with FP8 <https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8>`_
* [03/2024] `FP8 Training Support in SageMaker Model Parallelism Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-release-notes.html>`_
* [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility/>`_

.. image:: docs/examples/H200-NeMo-performance.png
  :width: 600
  :alt: H200

* [11/2023] `Inflection-2: The Next Step Up <https://inflection.ai/inflection-2>`_
* [11/2023] `Unleashing The Power Of Transformers With NVIDIA Transformer Engine <https://lambdalabs.com/blog/unleashing-the-power-of-transformers-with-nvidia-transformer-engine>`_
* [11/2023] `Accelerating PyTorch Training Workloads with FP8 <https://towardsdatascience.com/accelerating-pytorch-training-workloads-with-fp8-5a5123aec7d7>`_
* [09/2023] `Transformer Engine added to AWS DL Container for PyTorch Training <https://github.com/aws/deep-learning-containers/pull/3315>`_
* [06/2023] `Breaking MLPerf Training Records with NVIDIA H100 GPUs <https://developer.nvidia.com/blog/breaking-mlperf-training-records-with-nvidia-h100-gpus/>`_
* [04/2023] `Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) <https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1>`_

What is Transformer Engine?
===========================
.. overview-begin-marker-do-not-remove

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including
support for 8-bit floating point (FP8) precision on Hopper, Ada, and Blackwell GPUs, which provides
better performance with lower memory utilization in both training and inference. TE provides a
collection of highly optimized building blocks for popular Transformer architectures and an API,
similar to automatic mixed precision (AMP), that can be used seamlessly with your framework-specific
code. TE also includes a framework-agnostic C++ API that can be integrated with other deep learning
libraries to enable FP8 support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for
architectures such as BERT, GPT, and T5 become very memory- and compute-intensive. Most deep learning
frameworks train with FP32 by default. This is not essential, however, to achieve full accuracy for
many deep learning models. Mixed-precision training, which combines single-precision (FP32) with a
lower-precision format (e.g. FP16) when training a model, results in significant speedups with
minimal differences in accuracy compared to FP32 training. The Hopper GPU architecture introduced
FP8 precision, which offers improved performance over FP16 without significant degradation in
accuracy. Although all major deep learning frameworks support FP16, FP8 support is not available
natively in frameworks today.
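
To make the precision trade-off concrete, the following sketch simulates rounding a Python float to
the nearest value representable in the E4M3 FP8 format (4 exponent bits, 3 mantissa bits, largest
finite magnitude 448). It is a pure-Python illustration, not TE's casting kernel; in particular,
saturating out-of-range values to ±448 is a choice made by this sketch.

.. code-block:: python

    import math

    E4M3_MAX = 448.0                  # largest finite E4M3 magnitude
    E4M3_MIN_NORMAL = 2.0 ** -6       # smallest normal E4M3 magnitude
    E4M3_SUBNORMAL_STEP = 2.0 ** -9   # spacing between E4M3 subnormal values


    def round_to_e4m3(x: float) -> float:
        """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
        if x == 0.0 or math.isnan(x):
            return x
        sign = -1.0 if x < 0 else 1.0
        a = abs(x)
        if math.isinf(a):
            return sign * E4M3_MAX
        if a < E4M3_MIN_NORMAL:
            # Subnormal range: representable values are multiples of 2**-9.
            q = round(a / E4M3_SUBNORMAL_STEP) * E4M3_SUBNORMAL_STEP
        else:
            # a = m * 2**e with 0.5 <= m < 1. Keeping 4 significant bits of m
            # corresponds to the implicit leading 1 plus 3 mantissa bits.
            m, e = math.frexp(a)
            q = round(m * 16) / 16 * 2.0 ** e
        return sign * min(q, E4M3_MAX)


    for value in [1.0, 0.1, -3.3, 1000.0]:
        print(value, "->", round_to_e4m3(value))  # 1.0, 0.1015625, -3.25, 448.0

Near 1.0, neighboring E4M3 values are 0.125 apart, and nothing above 448 can be represented, which is
why FP8 training pairs the format with per-tensor scaling factors.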

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language
Model (LLM) libraries. It provides a Python API consisting of modules to easily build a Transformer
layer, as well as a framework-agnostic C++ library including the structs and kernels needed for FP8
support. Modules provided by TE internally maintain the scaling factors and other values needed for
FP8 training, greatly simplifying mixed-precision training for users.
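
As an illustration of that bookkeeping, here is a simplified, pure-Python model of *delayed scaling*,
the strategy selected by ``recipe.DelayedScaling`` in the examples in this README: the scaling factor
used in the current step is derived from absolute maxima (amax) recorded in earlier steps. The update
``scale = (FP8_MAX / amax) / 2**margin`` mirrors the formula described in the ``DelayedScaling``
documentation; the class and method names are hypothetical, and the sketch tracks a single tensor,
whereas TE tracks amax for each FP8 tensor a module uses.

.. code-block:: python

    from collections import deque

    # Largest finite magnitudes of the two FP8 formats.
    FP8_MAX = {"E4M3": 448.0, "E5M2": 57344.0}


    class DelayedScalingSketch:
        """Hypothetical, simplified delayed-scaling state for one tensor."""

        def __init__(self, fp8_format="E4M3", margin=0, amax_history_len=16):
            self.fp8_max = FP8_MAX[fp8_format]
            self.margin = margin
            self.amax_history = deque(maxlen=amax_history_len)  # oldest entries drop off
            self.scale = 1.0  # scale used before any amax has been recorded

        def record_and_update(self, values):
            """Record this step's amax and compute the scale for the next step."""
            self.amax_history.append(max(abs(v) for v in values))
            amax = max(self.amax_history)  # take the max over the history window
            if amax > 0.0:
                self.scale = (self.fp8_max / amax) / (2 ** self.margin)
            return self.scale


    state = DelayedScalingSketch(margin=0, amax_history_len=4)
    print(state.record_and_update([1.0, -4.0, 2.0]))  # 448 / 4 = 112.0
    print(state.record_and_update([0.5, 0.25]))       # window amax is still 4.0 -> 112.0

A scale of 112 maps the largest observed magnitude (4.0) exactly onto the E4M3 maximum; a positive
``margin`` leaves extra headroom for values that grow between steps.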

Highlights
==========

* Easy-to-use modules for building Transformer layers with FP8 support
* Optimizations (e.g. fused kernels) for Transformer models
* Support for FP8 on NVIDIA Hopper, Ada, and Blackwell GPUs
* Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere and later GPU architectures

Examples
========

PyTorch
^^^^^^^

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te
  from transformer_engine.common import recipe

  # Set dimensions.
  in_features = 768
  out_features = 3072
  hidden_size = 2048

  # Initialize model and inputs.
  model = te.Linear(in_features, out_features, bias=True)
  inp = torch.randn(hidden_size, in_features, device="cuda")

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      out = model(inp)

  loss = out.sum()
  loss.backward()
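
FP8 GEMMs also impose alignment requirements on tensor shapes. The sketch below assumes the rule that
TE PyTorch modules have asserted for FP8 execution in recent releases (after flattening the input to
2D, the number of rows divisible by 8 and the number of columns divisible by 16); the exact rule may
differ between versions, so ``fp8_shape_ok`` is a hypothetical pre-flight check, not TE's own
validation.

.. code-block:: python

    from math import prod


    def fp8_shape_ok(shape):
        """Hypothetical check: flatten to 2D (product of leading dims x last dim)
        and require rows % 8 == 0 and columns % 16 == 0."""
        *leading, last = shape
        rows = prod(leading)  # prod of an empty list is 1 (1-D input)
        return rows % 8 == 0 and last % 16 == 0


    # Shapes from the PyTorch example: input (2048, 768) and weight (3072, 768).
    print(fp8_shape_ok((2048, 768)), fp8_shape_ok((3072, 768)))  # True True
    print(fp8_shape_ok((2048, 770)))                              # False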


JAX
^^^

Flax
~~~~

.. code-block:: python

  import flax
  import jax
  import jax.numpy as jnp
  import transformer_engine.jax as te
  import transformer_engine.jax.flax as te_flax
  from transformer_engine.common import recipe

  BATCH = 32
  SEQLEN = 128
  HIDDEN = 1024

  # Initialize RNG and inputs.
  rng = jax.random.PRNGKey(0)
  init_rng, data_rng = jax.random.split(rng)
  inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      model = te_flax.DenseGeneral(features=HIDDEN)

      def loss_fn(params, other_vars, inp):
          out = model.apply({'params': params, **other_vars}, inp)
          return jnp.mean(out)

      # Initialize models.
      variables = model.init(init_rng, inp)
      other_variables, params = flax.core.pop(variables, 'params')

      # Construct the forward and backward function
      fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

      for _ in range(10):
          loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)

.. overview-end-marker-do-not-remove
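
The two examples use different FP8 formats: the PyTorch example selects ``recipe.Format.E4M3``, while
the JAX example selects ``recipe.Format.HYBRID``, which uses E4M3 for forward-pass tensors and E5M2
for gradients in the backward pass. The trade-off follows from the bit layouts described in the FP8
Formats for Deep Learning paper; the illustrative function below (not part of TE) derives each
format's largest finite value and smallest normal value from its exponent and mantissa widths.

.. code-block:: python

    def fp8_limits(exp_bits, man_bits, ieee_special_values):
        """Return (largest finite value, smallest normal value) of an FP8 format."""
        bias = 2 ** (exp_bits - 1) - 1
        min_normal = 2.0 ** (1 - bias)
        if ieee_special_values:
            # E5M2 follows IEEE 754: the all-ones exponent encodes inf/NaN.
            max_exp = (2 ** exp_bits - 2) - bias
            max_significand = 2.0 - 2.0 ** -man_bits
        else:
            # E4M3 keeps the all-ones exponent for normal numbers and reserves
            # only the all-ones mantissa pattern there for NaN (no infinities).
            max_exp = (2 ** exp_bits - 1) - bias
            max_significand = 2.0 - 2.0 ** -(man_bits - 1)
        return max_significand * 2.0 ** max_exp, min_normal


    print("E4M3:", fp8_limits(4, 3, ieee_special_values=False))  # (448.0, 0.015625)
    print("E5M2:", fp8_limits(5, 2, ieee_special_values=True))   # (57344.0, 6.103515625e-05)

E5M2 gives up one mantissa bit in exchange for a range 128 times wider (57344 / 448 = 128), which
suits gradients, whose magnitudes vary widely; E4M3 keeps the extra mantissa bit for the forward pass.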

Installation
============

System Requirements
^^^^^^^^^^^^^^^^^^^^

* **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere

* **OS:** Linux (official), WSL2 (limited support)

* **Software:**

  * CUDA: 12.1+ (Hopper/Ada/Ampere), 12.8+ (Blackwell) with compatible NVIDIA drivers
  * cuDNN: 9.3+
  * Compiler: GCC 9+ or Clang 10+ with C++17 support
  * Python: 3.12 recommended

* **Source Build Requirements:** CMake 3.18+, Ninja, Git 2.17+, pybind11 2.6.0+

* **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell)
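
The compute-capability requirement can be checked before enabling FP8. The hypothetical helper below
compares a ``(major, minor)`` pair against 8.9; in PyTorch, ``torch.cuda.get_device_capability()``
returns such a pair for the current device.

.. code-block:: python

    # Ada is 8.9, Hopper is 9.0, and Blackwell GPUs report 10.0 or higher.
    MIN_FP8_CAPABILITY = (8, 9)


    def supports_fp8(capability):
        """Return True if a (major, minor) compute capability meets the FP8 minimum."""
        major, minor = capability
        return (major, minor) >= MIN_FP8_CAPABILITY  # tuples compare element-wise


    for name, cap in [("Ampere A100", (8, 0)), ("Ada", (8, 9)), ("Hopper H100", (9, 0))]:
        print(name, supports_fp8(cap))

On a GPU machine, ``supports_fp8(torch.cuda.get_device_capability())`` gives the answer for the
current device.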

Installation Methods
^^^^^^^^^^^^^^^^^^^^

Docker (Recommended)
^^^^^^^^^^^^^^^^^^^
172
The quickest way to get started with Transformer Engine is to use the Docker images on the
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.
For example, to use the NGC PyTorch container interactively:

.. code-block:: bash

    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.01-py3

Here, ``25.01`` is the container version, corresponding to the January 2025 release.

**Benefits of using NGC containers:**

* All dependencies pre-installed with compatible versions and optimized configurations
* NGC PyTorch 23.08+ containers include FlashAttention-2
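
NGC container tags encode the release as ``YY.MM``, as in ``25.01`` above. The hypothetical helpers
below parse a tag such as ``25.01-py3`` and compare it with a minimum release, for example the 23.08
release from which NGC PyTorch containers include FlashAttention-2.

.. code-block:: python

    import re


    def parse_ngc_tag(image):
        """Return (year, month) from an image like 'nvcr.io/nvidia/pytorch:25.01-py3'."""
        tag = image.rsplit(":", 1)[-1]  # also accepts a bare tag such as '25.01-py3'
        match = re.match(r"(\d{2})\.(\d{2})", tag)
        if match is None:
            raise ValueError(f"unrecognized NGC tag: {tag!r}")
        yy, mm = (int(g) for g in match.groups())
        return 2000 + yy, mm


    def at_least(image, year, month):
        """True if the image's release is (year, month) or newer."""
        return parse_ngc_tag(image) >= (year, month)


    image = "nvcr.io/nvidia/pytorch:25.01-py3"
    print(parse_ngc_tag(image))      # (2025, 1)
    print(at_least(image, 2023, 8))  # True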

pip Installation
~~~~~~~~~~~~~~~~

**Prerequisites for pip installation:**

* A compatible C++ compiler
* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed

To install the latest stable version with pip:

.. code-block:: bash

    # For PyTorch integration
    pip install --no-build-isolation transformer_engine[pytorch]
    
    # For JAX integration
    pip install --no-build-isolation transformer_engine[jax]
    
    # For both frameworks
    pip install --no-build-isolation transformer_engine[pytorch,jax]

Alternatively, install directly from the GitHub repository:

.. code-block:: bash

    pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

When installing from GitHub, you can explicitly specify frameworks using the ``NVTE_FRAMEWORK`` environment variable:

.. code-block:: bash

    NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

Source Installation
~~~~~~~~~~~~~~~~~~~

See the `installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_ for instructions on building Transformer Engine from source.

Environment Variables
^^^^^^^^^^^^^^^^^^^^^
These environment variables can be set before installation to customize the build process:

* **CUDA_PATH**: Path to CUDA installation
* **CUDNN_PATH**: Path to cuDNN installation
* **CXX**: Path to C++ compiler
* **NVTE_FRAMEWORK**: Comma-separated list of frameworks to build for (e.g., ``pytorch,jax``)
* **MAX_JOBS**: Limit the number of parallel build jobs (default varies by system)
* **NVTE_BUILD_THREADS_PER_JOB**: Control the number of threads per build job
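
When a build is driven from a Python script rather than a shell, the same variables can be passed
through the ``env`` argument of ``subprocess.run``. The helpers below are a hypothetical sketch: they
only assemble the environment and the ``pip`` command shown in the pip Installation section, and leave
running it to the caller.

.. code-block:: python

    import os
    import sys

    STABLE_SPEC = "git+https://github.com/NVIDIA/TransformerEngine.git@stable"


    def te_build_env(frameworks, max_jobs=None, threads_per_job=None, base=None):
        """Return a copy of `base` (default: os.environ) with TE build variables set."""
        env = dict(os.environ if base is None else base)
        env["NVTE_FRAMEWORK"] = ",".join(frameworks)  # e.g. "pytorch,jax"
        if max_jobs is not None:
            env["MAX_JOBS"] = str(max_jobs)
        if threads_per_job is not None:
            env["NVTE_BUILD_THREADS_PER_JOB"] = str(threads_per_job)
        return env


    def pip_install_command(spec=STABLE_SPEC):
        """The GitHub pip install command, run with the current interpreter."""
        return [sys.executable, "-m", "pip", "install", spec]

A build could then be started with
``subprocess.run(pip_install_command(), env=te_build_env(["pytorch"], max_jobs=1), check=True)``.
Copying the environment keeps variables such as ``CUDA_PATH`` that are already set.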

Compiling with FlashAttention
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment.

You can verify which FlashAttention version is being used by setting these environment variables:

.. code-block:: bash

    NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python your_script.py

FlashAttention-2 compilation is known to be resource-intensive and to require a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out-of-memory errors during the installation of Transformer Engine. To work around this, try setting **MAX_JOBS=1** in the environment.
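
If ``MAX_JOBS=1`` makes the build too slow, one option is to derive a job count from the available
memory. The sketch below does this with an assumed memory budget per compile job (``gb_per_job``,
8 GB by default); that figure is a placeholder for illustration, not a documented requirement of
FlashAttention or TE, so adjust it after observing your own build.

.. code-block:: python

    import os


    def suggest_max_jobs(total_ram_bytes, cpu_count, gb_per_job=8):
        """Heuristic: as many jobs as fit in RAM, capped by CPU count, at least 1."""
        by_memory = total_ram_bytes // (gb_per_job * 1024 ** 3)
        return max(1, min(cpu_count, by_memory))


    def physical_ram_bytes():
        """Total physical memory, via POSIX sysconf names available on Linux."""
        return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")


    print(suggest_max_jobs(64 * 1024 ** 3, cpu_count=32))  # 8

On Linux, ``suggest_max_jobs(physical_ram_bytes(), os.cpu_count())`` gives a value to export as
``MAX_JOBS``.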

.. troubleshooting-begin-marker-do-not-remove
Troubleshooting
^^^^^^^^^^^^^^^^^^^

**Common Issues and Solutions:**

1. **ABI Compatibility Issues:**

   * **Symptoms:** ``ImportError`` with undefined symbols when importing transformer_engine
   * **Solution:** Ensure PyTorch and Transformer Engine are built with the same C++ ABI setting. Rebuild PyTorch from source with matching ABI.
   * **Context:** If you're using PyTorch built with a different C++ ABI than your system's default, you may encounter these undefined symbol errors. This is particularly common with pip-installed PyTorch outside of containers.

2. **Missing Headers or Libraries:**

   * **Symptoms:** CMake errors about missing headers (``cudnn.h``, ``cublas_v2.h``, ``filesystem``, etc.)
   * **Solution:** Install missing development packages or set environment variables to point to correct locations:

     .. code-block:: bash

         export CUDA_PATH=/path/to/cuda
         export CUDNN_PATH=/path/to/cudnn

   * If CMake can't find a C++ compiler, set the ``CXX`` environment variable.
   * Ensure all paths are correctly set before installation.

3. **Build Resource Issues:**

   * **Symptoms:** Compilation hangs, system freezes, or out-of-memory errors
   * **Solution:** Limit parallel builds:

     .. code-block:: bash

         MAX_JOBS=1 NVTE_BUILD_THREADS_PER_JOB=1 pip install ...

4. **Verbose Build Logging:**

   * For detailed build logs to help diagnose issues:

     .. code-block:: bash

         cd transformer_engine
         pip install -v -v -v --no-build-isolation .

.. troubleshooting-end-marker-do-not-remove
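
To tell the ABI compatibility issue apart from a missing installation, a script can attempt the
import and inspect the error. The function below is a hypothetical sketch that classifies an
``ImportError`` as an ABI mismatch when its message mentions an undefined symbol, matching the
symptom described above.

.. code-block:: python

    import importlib


    def diagnose_import(module_name="transformer_engine.pytorch"):
        """Try to import a module and classify the outcome as (status, detail)."""
        try:
            importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            # Also raised when a dependency such as torch is missing.
            return "not-installed", str(exc)
        except ImportError as exc:
            if "undefined symbol" in str(exc):
                return "abi-mismatch", str(exc)
            return "import-error", str(exc)
        return "ok", ""

When rebuilding, ``torch.compiled_with_cxx11_abi()`` reports which C++ ABI the installed PyTorch was
built with.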

Breaking Changes
================

v1.7: Padding mask definition for PyTorch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To unify the definition and usage of the attention mask across all three frameworks in Transformer Engine, the meaning of the padding mask in our PyTorch implementation has changed: ``True`` used to mean including the corresponding position in attention and now means excluding it. Since v1.7, all attention mask types follow the same definition, where ``True`` means masking out the corresponding position and ``False`` means including that position in the attention calculation.

For example:

.. code-block:: python

    # for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
    # and `0`s are the padding tokens,
    [a, a, a, 0, 0,
     b, b, 0, 0, 0,
     c, c, c, c, 0]
    # the padding mask for this batch before v1.7 is,
    [ True,  True,  True, False, False,
      True,  True, False, False, False,
      True,  True,  True,  True, False]
    # and for v1.7 onwards it should be,
    [False, False, False,  True,  True,
     False, False,  True,  True,  True,
     False, False, False, False,  True]
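
A padding mask in the v1.7+ convention can be built directly from per-sequence lengths, and a mask
written for the old convention converts by logical negation. The pure-Python sketch below (function
names are hypothetical) reproduces the example above; with PyTorch tensors, the same mask is
``torch.arange(max_len)[None, :] >= lengths[:, None]``, and an old-style boolean mask converts with
``~old_mask``.

.. code-block:: python

    def padding_mask(seq_lens, max_len):
        """v1.7+ convention: True marks a padded position to exclude from attention."""
        return [[pos >= n for pos in range(max_len)] for n in seq_lens]


    def from_pre_v17(mask):
        """Convert a pre-v1.7 mask (True = attend) to the v1.7+ convention."""
        return [[not v for v in row] for row in mask]


    # Sequence lengths 3, 2 and 4 padded to length 5, as in the example above.
    mask = padding_mask([3, 2, 4], max_len=5)
    for row in mask:
        print(row)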

FP8 Convergence
===============

FP8 has been tested extensively across different model architectures and configurations, and we found **no significant difference** between FP8 and BF16 training loss curves. FP8 has also been validated for accuracy on downstream LLM tasks (e.g. LAMBADA and WikiText). Below are examples of models tested for convergence across different frameworks.

+------------+------------------+---------------------------------------------------------------------------------------------------------+
| Model      | Framework        | Source                                                                                                  |
+============+==================+=========================================================================================================+
| T5-770M    |  JAX/T5x         | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x#convergence-and-performance|
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| MPT-1.3B   |  Mosaic Composer | https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1                                              |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-5B     |  JAX/Paxml       | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results               |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-5B     |  NeMo Framework  | Available on request                                                                                    |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| LLama2-7B  |  Alibaba Pai     | https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ                                                       |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| T5-11B     |  JAX/T5x         | Available on request                                                                                    |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| MPT-13B    |  Mosaic Composer | https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8         |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-22B    |  NeMo Framework  | Available on request                                                                                    |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| LLama2-70B |  Alibaba Pai     | https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ                                                       |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-175B   |  JAX/Paxml       | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results               |
+------------+------------------+---------------------------------------------------------------------------------------------------------+

Integrations
============

Transformer Engine has been integrated with popular LLM frameworks such as:

* `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_
* `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_
* `Lightning <https://github.com/Lightning-AI/lightning/issues/17172>`_
* `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_
* `NVIDIA JAX Toolbox <https://github.com/NVIDIA/JAX-Toolbox>`_
* `NVIDIA Megatron-LM <https://github.com/NVIDIA/Megatron-LM>`_
* `NVIDIA NeMo Framework <https://github.com/NVIDIA/NeMo-Megatron-Launcher>`_
* `Amazon SageMaker Model Parallel Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-core-features-v2-tensor-parallelism.html>`_
* `Levanter <https://github.com/stanford-crfm/levanter>`_
* `GPT-NeoX <https://github.com/EleutherAI/gpt-neox>`_
* `Hugging Face Nanotron <https://github.com/huggingface/nanotron>`_ - Coming soon!
* `Colossal-AI <https://github.com/hpcaitech/ColossalAI>`_ - Coming soon!
* `PeriFlow <https://github.com/friendliai/periflow-python-sdk>`_ - Coming soon!


Contributing
============

We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests,
follow the guidelines outlined in the `<CONTRIBUTING.rst>`_ guide.

Papers
======

* `Attention original paper <https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`_
* `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
* `Megatron-LM sequence parallel <https://arxiv.org/pdf/2205.05198.pdf>`_
* `FP8 Formats for Deep Learning <https://arxiv.org/abs/2209.05433>`_

Videos
======

* `What's New in Transformer Engine and FP8 Training | GTC 2024 <https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/>`_
* `FP8 Training with Transformer Engine | GTC 2023 <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_
* `FP8 for Deep Learning | GTC 2023 <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_
* `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_

.. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg
   :target: https://opensource.org/licenses/Apache-2.0