Unverified Commit e5bbc2e5 authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

Sparse attn + ops/runtime refactor + v0.3.0 (#343)



* Sparse attn + ops/runtime refactor + v0.3.0
Co-authored-by: default avatarArash Ashari <arashari@microsoft.com>
Co-authored-by: default avatarArash Ashari <arashari@microsoft.com>
parent 838f53b7
...@@ -9,9 +9,9 @@ import torch ...@@ -9,9 +9,9 @@ import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import math import math
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.pt.log_utils import logger from deepspeed.utils import logger
class FP16_UnfusedOptimizer(object): class FP16_UnfusedOptimizer(object):
......
...@@ -12,8 +12,8 @@ import argparse ...@@ -12,8 +12,8 @@ import argparse
from torch.optim import Optimizer from torch.optim import Optimizer
from typing import Union, List from typing import Union, List
import math import math
from deepspeed.pt.deepspeed_constants import * from deepspeed.runtime.constants import *
from deepspeed.pt.log_utils import logger from deepspeed.utils import logger
LR_SCHEDULE = 'lr_schedule' LR_SCHEDULE = 'lr_schedule'
LR_RANGE_TEST = 'LRRangeTest' LR_RANGE_TEST = 'LRRangeTest'
......
...@@ -9,7 +9,7 @@ Helper functions and classes from multiple sources. ...@@ -9,7 +9,7 @@ Helper functions and classes from multiple sources.
import torch import torch
from torch._six import inf from torch._six import inf
from deepspeed.pt.log_utils import logger from deepspeed.utils import logger
class CheckOverflow(object): class CheckOverflow(object):
......
...@@ -3,9 +3,8 @@ Copyright (c) Microsoft Corporation ...@@ -3,9 +3,8 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license. Licensed under the MIT license.
""" """
#from deepspeed.pt.deepspeed_constants import * from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.pt.deepspeed_config_utils import get_scalar_param from deepspeed.utils import logger
from deepspeed.pt.log_utils import logger
######################################### #########################################
# ZeRO optimization # ZeRO optimization
......
...@@ -4,11 +4,11 @@ import torch.distributed as dist ...@@ -4,11 +4,11 @@ import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from collections import defaultdict from collections import defaultdict
from deepspeed.pt.zero_utils import _initialize_parameter_parallel_groups from deepspeed.runtime.zero.utils import _initialize_parameter_parallel_groups
from deepspeed.pt.log_utils import log_dist, logger from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler from deepspeed.runtime.utils import get_grad_norm, CheckOverflow
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_OPTIMIZER_STATES from deepspeed.utils import logger, log_dist
def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size): def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
......
...@@ -10,16 +10,15 @@ import math ...@@ -10,16 +10,15 @@ import math
from torch._six import inf from torch._six import inf
from torch.autograd import Variable from torch.autograd import Variable
from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_GRADIENTS from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.utils import logger
#Toggle this to true to enable correctness test #Toggle this to true to enable correctness test
#with gradient partitioning and without #with gradient partitioning and without
pg_correctness_test = False pg_correctness_test = False
from deepspeed.pt.log_utils import logger
try: try:
from apex_C import flatten from apex_C import flatten
from apex_C import unflatten from apex_C import unflatten
......
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from deepspeed.pt.log_utils import logger from deepspeed.utils import logger
def _initialize_parameter_parallel_groups(parameter_parallel_size=None): def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
......
from deepspeed.utils.logging import logger, log_dist
...@@ -6,7 +6,7 @@ import time ...@@ -6,7 +6,7 @@ import time
import psutil import psutil
import torch import torch
from deepspeed.pt.log_utils import logger from deepspeed.utils import logger
def print_rank_0(message): def print_rank_0(message):
......
...@@ -335,3 +335,43 @@ Enabling and configure ZeRO memory optimizations ...@@ -335,3 +335,43 @@ Enabling and configure ZeRO memory optimizations
| Description | Default | | Description | Default |
| ------------------------------------------------------------ | ------- | | ------------------------------------------------------------ | ------- |
| Logs the forward and backward time for each checkpoint function | `false` | | Logs the forward and backward time for each checkpoint function | `false` |
### Sparse Attention
***sparse\_attention***: [dictionary]
| Fields | Value | Example |
| ------ | ------------------------------------------------------------ | ------------------------------ |
| mode | A string determining sparsity structure type. Deepspeed currently supports `"dense"`, `"fixed"`, `"bigbird"`, `"bslongformer"`, and `"variable"`. | `"fixed"` |
| block | An integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. | 16 |
| different\_layout\_per\_head | A boolean determining if each head should be assigned a different sparsity layout; this will be satisfied based on availability. | false |
| num\_local\_blocks | An integer determining the number of random blocks in each block row; only used in `"fixed"` mode. | 4 |
| num\_global\_blocks | An integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention; used in `"fixed"` and `"bigbird"` modes. | 1 |
| attention | A string determining attention type. Attention can be `"unidirectional"`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty. Or it can be `"bidirectional"`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular; used in `"fixed"` and `"variable"` modes. | `"bidirectional"` |
| horizontal\_global\_attention | A boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `"bidirectional"`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks; used in `"fixed"` and `"variable"` modes. | false |
| num\_different\_global\_patterns | An integer determining number of different global attentions layouts. While global attention can be fixed by which block/s are representative of any local window, since there are multi-heads, each head can use a different global representative; used only in `"fixed"` mode. | 4 |
| num\_random\_blocks | An integer determining the number of random blocks in each block row; used in `"variable"` and `"bigbird"` modes. | 0 |
| local\_window\_blocks | A list of integers determining the number of blocks in each local attention window. It assumes first number determines # of blocks in the first local window, second the second window, ..., and the last number determines the number of blocks in the remaining local windows; only used in `"variable"` mode. | [4] |
| global\_block\_indices | A list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Notice that if global\_block\_end\_indices parameter is set, this parameter is used as starting index of each global window; used in `"variable"` and `"bslongformer"` modes. | [0] |
| global\_block\_end\_indices | A list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size of global\_block\_indices parameter, and combining this two parameters, for each index i, blocks from global\_block\_indices[i] to global\_block\_end\_indices[i], exclusive, are considered as global attention; used in `"variable"` and `"bslongformer"` modes. | None |
| num\_sliding\_window\_blocks | An integer determining the number of blocks in sliding local attention window; used in `"bigbird"` and `"bslongformer"` modes. | 3 |
Example of ***sparse\_attention***
```json
"sparse_attention": {
"mode": "fixed",
"block": 16,
"different_layout_per_head": true,
"num_local_blocks": 4,
"num_global_blocks": 1,
"attention": "bidirectional",
"horizontal_global_attention": false,
"num_different_global_patterns": 4,
"num_random_blocks": 0,
"local_window_blocks": [4],
"global_block_indices": [0],
"global_block_end_indices": None,
"num_sliding_window_blocks": 3
}
```
---
layout: single
title: "DeepSpeed Sparse Attention"
excerpt: ""
categories: news
new_post: true
date: 2020-09-09 01:00:00
---
Attention-based deep learning models such as the transformers are highly effective in capturing relationship between tokens in an input sequence, even across long distances. As a result, they are used with text, image, and sound-based inputs, where the sequence length can be in thousands of tokens. However, despite the effectiveness of attention modules to capture long term dependencies, in practice, their application to long sequence input is limited by compute and memory requirements of the attention computation that grow quadratically, `O(n^2)`, with the sequence length n.
To address this limitation, DeepSpeed offers a suite of sparse attention kernels --an instrumental technology that can reduce the compute and memory requirement of attention computation by orders-of-magnitude via block-sparse computation. The suite not only alleviates the memory bottleneck of attention calculation, but also performs sparse computation efficiently. Its APIs allow convenient integration with any transformer-based models. Along with providing a wide spectrum of sparsity structures, it has the flexibility of handling any user-defined block-sparse structures. More specifically, sparse attention (SA) can be designed to compute local attention between nearby tokens, or global attention via summary tokens computed with local attention. Moreover, SA can also allow random attention, or any combination of local, global, and random attention as shown in the following figure with blue, orange, and green blocks, respectively. As a result, SA decreases the memory footprint to `O(wn)`, in which `1 < w < n` is a parameter, whose value depends on the attention structure.
![Variable sparsity structure](/assets/images/sa_variable_sparsity_structure.png){: .align-center}
This library is PyTorch based and develops required kernels through [Triton](https://github.com/ptillet/triton) platform; kernels are not written in CUDA, which leaves the door open for CPU/OpenCL/Vulkan support in the future. The library is an extension to DeepSpeed and can be used through DeepSpeed as well as stand alone.
Block-sparse computations handled by DeepSpeed Sparse Attention kernels are illustrated in following figures for forward and backward passes respectively. In the figures, `S` stands for a `block-sparse matrix` and `D` a `dense matrix`.
![Sparse attention forward pass](/assets/images/sa_forward_pass.png){: .align-center}
![Sparse attention backward pass](/assets/images/sa_backward_pass.png){: .align-center}
To learn more about Sparsity Config, and also how to use this library, please check our [tutorial](https://github.com/microsoft/DeepSpeed-internal/tree/master/docs/_tutorials/sparse_attention.md) that provides detailed information about it.
## Performance Results
* **Power over 10x longer sequences**
In a pre-training experiment, we ran BERT model under three settings: dense, dense with activation checkpoint, and sparse (SA) with activation checkpoint. SA empowers 10x and 16x longer sequences comparing with dense for BERT base and large, respectively. Following figure shows the longest sequence length runnable in BERT base and large model; experiment is performed with batch size 1 on a single Nvidia V100 GPU-32GB memory.
![Maximum sequence runnable on BERT](/assets/images/sa_maximum_sequence_runnable_on_bert.png){: .align-center}
* **up to 6.3x faster computation**
We continued the pre-training experiment for different batch sizes and sequence lengths, using [BERT base/large](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert) and [Megatron GPT2](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM). In this experiment we let the training to continue for 100 iteration and recorded the average time per last 30 iterations. SA reduces total computation comparing with dense and improves training speed: the boost is higher with increased sequence length and it is up to 6.3x faster for BERT base, 5.3x for BERT large, and 6.1x for GPT2. Following charts show these results.
![Training time for BERT base with varying sequence length](/assets/images/sa_bert_base_time_result.png){: .align-center}
![Training time for BERT large with varying sequence length](/assets/images/sa_bert_large_time_result.png){: .align-center}
![Training time for GPT2 with varying sequence length](/assets/images/sa_gpt2_time_result.png){: .align-center}
* **higher accuracy**
Related works along the line of sparse attention ([Sparse Transformer](https://arxiv.org/pdf/1904.10509.pdf), [Longformer](https://arxiv.org/pdf/2004.05150.pdf), [BigBird](https://arxiv.org/pdf/2007.14062.pdf)) have shown comparable or higher accuracy than full attention. Our experience is well aligned. In addition to lower memory overhead and faster computation, we also observe cases in production where SA reaches higher accuracy and faster convergence. The following chart illustrates accuracy of training a production model based on BERT for long document comprehension (2,048 sequence length). The experiment is performed in three settings: dense starting from scratch, SA starting from scratch, and SA continued training from a checkpoint of using dense with sequence length of 512. We have observed that, for pre-training from scratch, SA converges faster with higher accuracy comparing with dense. Furthermore, SA continuing training from a pre-trained checkpoint performs even better, with respect to both time and accuracy.
![Accuracy of long document comprehension application](/assets/images/sa_long_document_comprehension_result.png){: .align-center}
* **flexibility to handle any block-sparse structure**
DeepSpeed Sparse Attention suite does not target at any specific sparse structure but enables model scientists to explore any block sparse structure with efficient system support. Currently, we have added popular sparse structure like:
* [Fixed](https://arxiv.org/pdf/1904.10509.pdf) (from OpenAI Sparse Transformer)
* [BigBird](https://arxiv.org/pdf/2007.14062.pdf) (from Google)
* BSLongformer (Block-Sparse implementation of [Longformer](https://arxiv.org/pdf/2004.05150.pdf) from AI2)
We also define a template to have `variable` structure (top figure), which can be used to simply customize any block-sparse random/local/global attention pattern. In addition to this list, user can add any other sparsity structure as described in [tutorial](https://github.com/microsoft/DeepSpeed-internal/tree/master/docs/_tutorials/sparse_transformer.md) section.
* **comparison with state of the art, Longformer**
We compared SA with Longformer, a state-of-the-art sparse structure and implementation. In our experiment, SA uses `Fixed` sparsity, and two implementations have comparable accuracy. On system performance, SA outperforms Longformer both in training and inference:
* 1.47x faster execution pre-training MLM on Wikitext103
We ran an experiment following the [notebook](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) offered by Longformer. In this experiment, we pre-train an MLM model using RoBERTa-base checkpoint. This is done on 8 V100-SXM2 GPU. Following table shows the details of the result in which using DeepSpeed Sparse Attention shows 1.47x speed up.
|Model |Local Window Size |BPC |Train Step |Time Per Iteration |Time Improvement |Accuracy improvement |
|-------------------|------------------|--------|------------|--------------------|------------------|----------------------|
|RoBERTa Checkpoint | |2.5326 | |
|Longformer |512 |2.6535 |0 | |1.47 |1.01 |
|Sparse Attention | |2.6321 | | | | |
|Longformer | |1.6708 |3k |1.6280 | |1.01 |
|Sparse Attention | |1.6613 | |1.1059 | | |
|Longformer |64 |5.7840 |0 | |1.31 |1.46 |
|Sparse Attention | |3.9737 | | | | |
|Longformer | |2.0466 |3k |1.4855 | |1.09 |
|Sparse Attention | |1.8693 | |1.1372 | | |
* 3.13x faster execution inference on BERT-Base
Through our Long Document Comprehension application we described above, we also checked the inference time for different window sizes testing BERT model on a `2,048` Sequence Length and batch size `1`. In this experiment, we noticed up to `3.13X` speed up replacing Bert Attention with DeepSpeed Sparse Attention instead of Longformer Attention. Following table shows the complete result.
|Local Window Size |Time Improvement|
|--------------------|----------------|
|512 |3.13 |
|256 |2.29 |
|128 |2.16 |
|64 |1.5 |
|32 |1.24 |
|16 |1.23 |
---
title: "DeepSpeed Sparse Attention"
---
In this tutorial we describe how to use DeepSpeed Sparse Attention and its building-block kernels through DeepSpeed launcher or integrating individual kernels into your code.
**Note:** Currently DeepSpeed Sparse Attention can be used only on Nvidia V100 GPU using Cuda 10.1 or 10.2.
{: .notice--warning}
## How to use
DeepSpeed Sparse Attention can be used as a feature through DeepSpeed, or simply integrated with any Transformer model as a self-attention module alone. Further, the building block kernels, matrix multiplication and softmax can be used separately. To use sparse attention alone, you can simply install DeepSpeed and import any of the following modules from it; example:
```python
from deepspeed.ops.sparse_attention import SparseSelfAttention
```
Following we describe Sparse Attention modules:
* **MatMul**: This module handles block-sparse matrix-matrix multiplication. Currently it supports SDD, DSD, and DDS as described in [DeepSpeed Sparse Attention](https://github.com/microsoft/DeepSpeed-internal/tree/master/docs/_posts/2020-09-09-sparse-attention.md) section.
* **Softmax**: This module applies block sparse softmax. It handles both forward and backward pass.
* **SparseSelfAttention**: This module uses MatMul and Softmax kernels and generates Context Layer output given Query, Keys and Values. It is a simplified version of common operations in any self-attention layer. It can also apply:
* `Relative position embedding`
* `Attention mask`
* `Key padding mask`
on the intermediate attention scores. For more details about SelfAttantion, please check [MultiHeadAttention](https://pytorch.org/docs/master/generated/torch.nn.MultiheadAttention.html#multiheadattention).
* **BertSparseSelfAttention**: This module contains a simplified BertSelfAttention layer that can be used instead of original dense Bert Self-Attention layer. Our implementation is based on [DeepSpeedExample](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373-#L434).
* **SparseAttentionUtils**: This module provides few utility functions to handle adapting pre-trained model with sparse attention:
* `replace_model_self_attention_with_sparse_self_attention`: If you have currently loaded a model and want to replace self-attention module with sparse self-attention, you can simply use this function to handle it for you. It currently handles BERT and RoBERTa based pre-trained models, but you can extend it base on your model type if it is different from these two. You also need to extend the position embedding to handle new sequence length; this can be done using `extend_position_embedding` function.
* `update_tokenizer_model_max_length`: This function simply updates maximum position embedding in your tokenizer with the new value.
* `extend_position_embedding`: This function extends the position embedding based on the current values. For example, if you have a 128 max sequence length model and extending it to a 1k sequence length, it replicates current embeddings 8 times to initialize new embedding. Experimentally we have seen such initialization works much better than initializing from scratch; leads to faster convergence.
* `pad_to_block_size`: This function pads input tokens and attention mask on sequence length dimension to be multiple of block size; this is a requirement for SA.
* `unpad_sequence_output`: This function unpads sequence output if inputs of the model were padded.
* **SparsityConfig**: this is an abstract class for sparsity structure. Any sparsity structure extends this class and writes its own `make_layout` function. DeepSpeed currently provides the following structures that will be described in next section:
* `FixedSparsityConfig`
* `BSLongformerSparsityConfig`
* `BigBirdSparsityConfig`
* `VariableSparsityConfig`
### BertSparseSelfAttention Example
We have currently integrated Sparse Attention with our [bing_bert](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py) code that can be used as an example for integration. In this example, we replace, BertSelfAttention module with BertSparseSelfAttention. Using DeepSpeed launcher, you can enable sparse attention using `deepspeed_sparse_attention` argument ([example](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/ds_sa_train_bert_bsz64k_seq128.sh)) and add your desired sparsity config into the [DeepSpeed config file](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_bsz64k_lamb_config_seq128.json). In this example, we have used `fixed` sparsity mode. Further, you need to pad sequence dimension of `input_ids` and `attention_mask` to be multiple of sparse block size. As mentioned above, DeepSpeed provides utility functions for padding and unpadding and you can check our [example](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py) to see where and how pad and unpad the inputs or outputs of the model.
**Note:** Currently DeepSpeed Transformer Kernels do not support Sparse Attention. To use Sparse Attention, you need to disable Transformer Kernels!
{: .notice--warning}
### Sparsity structures
Following we describe supported sparsity structures, their parameter set and the flexibility of adding arbitrary sparsity pattern on the self-attention layer.
* **SpasityConfig**:
This module, is the parent class for all sparsity structures and contains the shared features of all sparsity structures. It takes the following parameters:
* `num_heads`: an integer determining number of attention heads of the layer.
* `block`: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such square blocks; `Block X Block`.
* `different_layout_per_head`: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
* **Fixed** (FixedSparistyConfig):
This structure is based on [Generative Modeling with Sparse Transformers](https://arxiv.org/abs/1904.10509) from OpenAI, in which local and global attention is fixed by the given parameters:
* `num_local_blocks`: an integer determining the number of blocks in local attention window. As it is illustrated in the below figure (adapted from original paper), tokens in a local window, attend to all tokens local to them. In the case of autoregressive model, as in the figure, tokens attend to tokens appearing before them in the local window. And in the case of Masked model such as BERT, attention is bidirectional.
* `num_global_blocks`: an integer determining how many consecutive blocks in a local window is used as the representative of the window for global attention; illustrated in the figure below as well.
* `attention`: a string determining attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure.
* `horizontal_global_attention`: a boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `bidirectional`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks.
* `num_different_global_patterns`: an integer determining number of different global attentions layouts. While global attention can be fixed by which block/s are representative of any local window, since there are multi-heads, each head can use a different global representative. For example, with 4 blocks constructing local window and global attention size of a single block, we can have 4 different versions in which the first, second, third, or forth block of each local window can be global representative of that window. This parameter determines how many of such patterns we want. Of course, there is a limitation based on `num_local_blocks` and `num_global_blocks`. Further, if you set this to more than one, you need to set `different_layout_per_head` to `True`.
![Fixed sparsity structure](/assets/images/sa_fixed_sparsity_structure.png)
* **BSLongformer** (BSLongformerSparistyConfig):
This structure is an edited version of [Longformer: The Long-Document Transformer](https://arxiv.org/pdf/2004.05150.pdf), in which instead of single token-wise sparsity, we offer block of tokens sparsity. Parameters that define this patters are:
* `num_sliding_window_blocks`: an integer determining the number of blocks in sliding local attention window.
* `global_block_indices`: a list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Notice that if `global_block_end_indices` parameter is set, this parameter is used as starting index of each global window.
* `global_block_end_indices`: a list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size as `global_block_indices` parameter, and combining this two parameters, for each index `i`, blocks from `global_block_indices[i]` to `global_block_end_indices[i]` (exclusive) are considered as global attention block.
* **BigBird** (BigBirdSparsityConfig):
This structure is based on [Big Bird: Transformers for Longer Sequences](https://arxiv.org/pdf/2007.14062.pdf). It somehow combines the idea of `fixed` and `longformer` patterns along with random attention. Following parameters define this structure:
* `num_random_blocks`: an integer determining how many blocks in each row block are attended randomly.
* `num_sliding_window_blocks`: an integer determining the number of blocks in sliding local attention window.
* `num_global_blocks`: an integer determining how many consecutive blocks, starting from index 0, are considered as global attention. Global block tokens will be attended by all other block tokens and will attend to all other block tokens as well.
* **Variable** (VariableSparsityConfig):
This structure also combines the idea of local, global and random attention. Further, it has the flexibility of defining variable size local windows. Following is the list of parameters that define this structure:
* `num_random_blocks`: an integer determining how many blocks in each row block are attended randomly.
* `local_window_blocks`: a list of integers determining the number of blocks in each local attention window. It assumes first number determines # of blocks in the first local window, second number the second window, ..., and the last number determines the number of blocks in the remaining local windows.
* `global_block_indices`: a list of integers determining which blocks are considered as global attention. Given indices, determine the blocks that all other token blocks attend to and they attend to all other token blocks. Notice that if `global_block_end_indices` parameter is set, this parameter is used as starting index of each global window.
* `global_block_end_indices`: a list of integers determining end indices of global window blocks. By default this is not used. But if it is set, it must have the same size as `global_block_indices` parameter, and combining this two parameters, for each index `i`, blocks from `global_block_indices[i]` to `global_block_end_indices[i]` (exclusive) are considered as global attention block.
* `attention`: a string determining attention type. Attention can be `unidirectional`, such as autoregressive models, in which tokens attend only to tokens appear before them in the context. Considering that, the upper triangular of attention matrix is empty as above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens before or after them. Then, the upper triangular part of the attention matrix is mirror of the lower triangular in the above figure.
* `horizontal_global_attention`: a boolean determining if blocks that are global representative of a local window, also attend to all other blocks. This is valid only if attention type is `bidirectional`. Looking at the attention matrix, that means global attention not only includes the vertical blocks, but also horizontal blocks
Figure bellow illustrates an example of `variable` sparsity, in which blue, orange and green blocks illustrate local, global, and random attention blocks respectively.
![Variable sparsity structure](/assets/images/sa_variable_sparsity_structure.png)
Further, we provide a `dense` pattern (`DenseSparsityConfig`), that can be used for the sake of testing while it represents the full attention.
### How to expand block-base sparsity patterns
Our building block kernels, block-based `MatMul` & `Softmax`, can accept any block-based sparsity. This provides the flexibility to apply any block-based sparsity pattern to attention score. To define and apply a new sparsity pattern, you can simply follow any of the above sparsity structures. You need to add a new class that expands `SparsityConfig` and define `make_layout` function based on how your sparsity is structured. You can add any extra parameters you may need or just use default parameters of the parent class.
...@@ -19,6 +19,9 @@ training](https://www.deepspeed.ai/news/2020/05/27/fastest-bert-training.html). ...@@ -19,6 +19,9 @@ training](https://www.deepspeed.ai/news/2020/05/27/fastest-bert-training.html).
To use transformer kernel for training a model, you should Integrate DeepSpeed into your training script using the [Getting Started](/getting-started/) guide. To use transformer kernel for training a model, you should Integrate DeepSpeed into your training script using the [Getting Started](/getting-started/) guide.
**Note:** Currently DeepSpeed Transformer Kernels do not support Sparse Attention. To use Sparse Attention, you need to disable Transformer Kernels!
{: .notice--warning}
### **Integrate Transformer Kernel** ### **Integrate Transformer Kernel**
First of all, you need to integrate transformer kernel into the top-level model. Here, we show an example of instantiating the transformer kernel using the Pre-LN BERT-Large configuration settings. This configuration has 24 layers with 1024 hidden-dimension and uses the sequence length of 128 and batch size of 64. To add all these layers, we copy the same layer specification `num_hidden_layer` times with different IDs inside a ModuleList. First of all, you need to integrate transformer kernel into the top-level model. Here, we show an example of instantiating the transformer kernel using the Pre-LN BERT-Large configuration settings. This configuration has 24 layers with 1024 hidden-dimension and uses the sequence length of 128 and batch size of 64. To add all these layers, we copy the same layer specification `num_hidden_layer` times with different IDs inside a ModuleList.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment