Unverified Commit e5bbc2e5 authored by Jeff Rasley, committed by GitHub

Sparse attn + ops/runtime refactor + v0.3.0 (#343)

* Sparse attn + ops/runtime refactor + v0.3.0
Co-authored-by: Arash Ashari <arashari@microsoft.com>
parent 838f53b7
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import torch
import random
class SparsityConfig:
"""Abstract Configuration class to store `sparsity configuration of a self attention layer`.
It contains shared property of different block-sparse sparsity patterns. However, each class needs to extend it based on required property and functionality.
"""
def __init__(self, num_heads, block=16, different_layout_per_head=False):
"""Initialize the Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
"""
self.num_heads = num_heads
self.block = block
self.different_layout_per_head = different_layout_per_head
self.num_layout_heads = num_heads if different_layout_per_head else 1
def setup_layout(self, seq_len):
"""Create layout tensor for the given sequence length
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout of all head; initialized with zero
"""
if (seq_len % self.block != 0):
raise ValueError(
f'Sequence Length, {seq_len}, needs to be divisible by Block size {self.block}!'
)
num_blocks = seq_len // self.block
# TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
layout = torch.zeros((self.num_heads, num_blocks, num_blocks), dtype=torch.int64)
return layout
def check_and_propagate_first_head_layout(self, layout):
"""If all heads require same sparsity layout, it propagate first head layout to all heads
Arguments:
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head
"""
if not self.different_layout_per_head:
layout[1:self.num_heads, :, :] = layout[0, :, :]
return layout
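# Illustrative sketch (not part of the original source): how the block math works
# out for the base class. The head count and sequence length below are assumed
# example values.
#   cfg = SparsityConfig(num_heads=12, block=16)
#   layout = cfg.setup_layout(seq_len=1024)
#   # layout is a zero tensor of shape (12, 1024 // 16, 1024 // 16) == (12, 64, 64);
#   # subclasses mark attended blocks by setting entries to 1.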
class DenseSparsityConfig(SparsityConfig):
"""Configuration class to store `Dense` configuration.
In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension.
"""
def __init__(self, num_heads, block=16, different_layout_per_head=False):
"""Initialize the Dense Sparsity Pattern Config.
In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension.
Arguments:
num_heads: required: an integer determining the number of attention heads of the layer.
block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, in which this parameter defines the size of such blocks, `Block X Block`.
different_layout_per_head: optional: included only for consistency with the other sparsity configs; it can be ignored for DenseSparsityConfig.
"""
super().__init__(num_heads, block, different_layout_per_head)
def make_layout(self, seq_len):
"""Set 1 to all blocks of the layout meanins the pattern is dense; not sparse.
Arguments:
seq_len: required: an integer determining the underling sequence length; must be <= max sequence length
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; for dense everything is 1
"""
layout = self.setup_layout(seq_len)
layout[:, :, :] = 1
return layout
class FixedSparsityConfig(SparsityConfig):
"""Configuration class to store `Fixed` sparsity configuration.
For more details about this sparsity config, please see `Generating Long Sequences with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
"""
def __init__(self,
num_heads,
block=16,
different_layout_per_head=False,
num_local_blocks=4,
num_global_blocks=1,
attention='bidirectional',
horizontal_global_attention=False,
num_different_global_patterns=1):
"""Initialize `Fixed` Sparsity Pattern Config.
For a usage example, please see the DeepSpeed Sparse Transformer tutorial (TODO).
Arguments:
num_heads: required: an integer determining the number of attention heads of the layer.
block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, in which this parameter defines the size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
num_local_blocks: optional: an integer determining the number of blocks in the local attention window.
num_global_blocks: optional: an integer determining how many consecutive blocks in a local window are used as the representative of the window for global attention.
attention: optional: a string determining the attention type. Attention can be `unidirectional`, as in autoregressive models, in which tokens attend only to tokens that appear before them in the context; in that case, the upper triangular part of the attention matrix is empty. Or it can be `bidirectional`, as in BERT, in which tokens can attend to any other tokens before or after them; then, the upper triangular part of the attention matrix mirrors the lower triangular part.
horizontal_global_attention: optional: a boolean determining whether blocks that are the global representatives of a local window also attend to all other blocks. This is valid only if the attention type is `bidirectional`. Looking at the attention matrix, this means global attention includes not only the vertical blocks but also the horizontal blocks.
num_different_global_patterns: optional: an integer determining the number of different global attention layouts. While global attention can be fixed by which block(s) are representative of any local window, since there are multiple heads, each head can use a different global representative. For example, with a 4-block local window and a global attention size of 1 block, we can have 4 different versions in which the first, second, third, or fourth block of each local window is the global representative of that window. This parameter determines how many such patterns we want. Of course, there is a limitation based on num_local_blocks and num_global_blocks.
"""
super().__init__(num_heads, block, different_layout_per_head)
self.num_local_blocks = num_local_blocks
if (num_local_blocks % num_global_blocks != 0):
raise ValueError(
f'Number of blocks in a local window, {num_local_blocks}, must be divisible by the number of global blocks, {num_global_blocks}!'
)
self.num_global_blocks = num_global_blocks
if (attention != 'unidirectional' and attention != 'bidirectional'):
raise NotImplementedError(
'only \"uni/bi-directional\" attentions are supported for now!')
self.attention = attention
if (attention != 'bidirectional' and horizontal_global_attention):
raise ValueError(
'only \"bi-directional\" attentions can support horizontal global attention!'
)
self.horizontal_global_attention = horizontal_global_attention
if (num_different_global_patterns > 1 and not different_layout_per_head):
raise ValueError(
f'Number of different layouts cannot be more than one when you have set a single layout for all heads! Set different_layout_per_head to True.'
)
if (num_different_global_patterns > (num_local_blocks // num_global_blocks)):
raise ValueError(
f'Number of layout versions (num_different_global_patterns), {num_different_global_patterns}, cannot be larger than number of local window blocks divided by number of global blocks, {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!'
)
self.num_different_global_patterns = num_different_global_patterns
def set_local_layout(self, h, layout):
"""Sets local attantion layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local layout is set
"""
num_blocks = layout.shape[1]
for i in range(0, num_blocks, self.num_local_blocks):
end = min(i + self.num_local_blocks, num_blocks)
for row in range(i, end):
for col in range(
i,
(row + 1 if self.attention == 'unidirectional' else end)):
layout[h, row, col] = 1
return layout
def set_global_layout(self, h, layout):
"""Sets global attantion layout used by the given head in the sparse attention.
Currently we set global blocks starting from the last block of a local window to the first one. That means if a local window consists of 4 blocks and global attention size is one block, we use block #4 in each local window as global. If we have different layout per head, then other heads will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global attentions, multiple head may have same global attentions.
Note) if horizontal_global_attention is set, global blocks will be set both horizontally and vertically.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
first_global_block_idx = self.num_local_blocks - (
1 + h % self.num_different_global_patterns) * self.num_global_blocks
# set global blocks for all complete local windows; the (possibly short) last window is handled below
end = num_blocks - (num_blocks % self.num_local_blocks)
for i in range(first_global_block_idx, end, self.num_local_blocks):
# vertical global attention
first_row = 0 if self.attention == 'bidirectional' else i
#(((i // self.num_local_blocks) + 1) * self.num_local_blocks)
#if (first_row < num_blocks):
layout[h, first_row:, i:i + self.num_global_blocks] = 1
# horizontal global attention; only in bidirectional attention
if (self.horizontal_global_attention):
layout[h, i:i + self.num_global_blocks, :] = 1
# set last global blocks; handle possible short last local window
if (end < num_blocks):
start = min(end + first_global_block_idx,
num_blocks - self.num_global_blocks)
end = start + self.num_global_blocks
# vertical global attention
first_row = 0 if self.attention == 'bidirectional' else start
#(((start // self.num_local_blocks) + 1) * self.num_local_blocks)
#if (first_row < num_blocks):
layout[h, first_row:, start:end] = 1
# horizontal global attention
if (self.horizontal_global_attention):
layout[h, start:end, :] = 1
return layout
def make_layout(self, seq_len):
"""Generates `Fixed` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining the underlying sequence length; must be divisible by the block size
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `Fixed` sparsity layout of all heads
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_local_layout(h, layout)
layout = self.set_global_layout(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
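# Illustrative usage sketch (not part of the original source); the argument values
# below are assumptions chosen for the example.
#   fixed = FixedSparsityConfig(num_heads=16,
#                               block=16,
#                               num_local_blocks=4,
#                               num_global_blocks=1,
#                               attention='unidirectional')
#   layout = fixed.make_layout(seq_len=1024)
#   # layout has shape (16, 64, 64); each head gets 4-block local windows plus one
#   # global block per window, lower-triangular because attention is unidirectional.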
class VariableSparsityConfig(SparsityConfig):
"""Configuration class to store `Variable` sparsity configuration.
This layout is an extension of FixedSparsityConfig in which:
- the user can add a random layout; the default value is zero, meaning no random blocks
- the user can provide a list of local block sizes
- the user can provide a list of global block indices.
For more details about the `Fixed` sparsity config, please see `Generating Long Sequences with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
This class extends the parent class `SparsityConfig` and customizes it for `Variable` sparsity.
"""
def __init__(self,
num_heads,
block=16,
different_layout_per_head=False,
num_random_blocks=0,
local_window_blocks=[4],
global_block_indices=[0],
global_block_end_indices=None,
attention='bidirectional',
horizontal_global_attention=False):
"""Initialize `Variable` Sparsity Pattern Config.
For a usage example, please see the DeepSpeed Sparse Transformer tutorial (TODO).
Arguments:
num_heads: required: an integer determining the number of attention heads of the layer.
block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, in which this parameter defines the size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability. Currently this sparsity config can only assign a single layout to all heads; it needs to be extended to support a different layout per head.
num_random_blocks: optional: an integer determining the number of random blocks in each block row.
local_window_blocks: optional: a list of integers determining the number of blocks in each local attention window. The first number determines the number of blocks in the first local window, the second number the second window, ..., and the last number determines the number of blocks in all remaining local windows.
global_block_indices: optional: a list of integers determining which blocks are considered as global attention. The given indices determine the blocks that all other token blocks attend to, and these blocks attend to all other token blocks. The default value is index 0 only. Notice that if the global_block_end_indices parameter is set, this parameter is used as the starting index of each global window.
global_block_end_indices: optional: a list of integers determining the end indices of the global window blocks. By default this is not used. But if it is set, it must have the same size as the global_block_indices parameter; combining these two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered as global attention.
attention: optional: a string determining the attention type. Attention can be `unidirectional`, as in autoregressive models, in which tokens attend only to tokens that appear before them in the context; in that case, the upper triangular part of the attention matrix is empty. Or it can be `bidirectional`, as in BERT, in which tokens can attend to any other tokens before or after them; then, the upper triangular part of the attention matrix mirrors the lower triangular part.
horizontal_global_attention: optional: a boolean determining whether blocks that are the global representatives of a local window also attend to all other blocks. This is valid only if the attention type is `bidirectional`. Looking at the attention matrix, this means global attention includes not only the vertical blocks but also the horizontal blocks.
"""
super().__init__(num_heads, block, different_layout_per_head)
self.num_random_blocks = num_random_blocks
self.local_window_blocks = local_window_blocks
self.global_block_indices = global_block_indices
if (global_block_end_indices is not None):
if (len(global_block_indices) != len(global_block_end_indices)):
raise ValueError(
f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!'
)
for _, (start_idx, end_idx) in enumerate(zip(global_block_indices, global_block_end_indices)):
if start_idx >= end_idx:
raise ValueError(
f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!'
)
self.global_block_end_indices = global_block_end_indices
if (attention != 'unidirectional' and attention != 'bidirectional'):
raise NotImplementedError(
'only \"uni/bi-directional\" attentions are supported for now!')
self.attention = attention
if (attention != 'bidirectional' and horizontal_global_attention):
raise ValueError(
'only \"bi-directional\" attentions can support horizontal global attention!'
)
self.horizontal_global_attention = horizontal_global_attention
def set_random_layout(self, h, layout):
"""Sets random attantion layout used by the given head in the sparse attention.
Note) By default, it assumes there will be a unique random block layout for all heads; unless `different_layout_per_head` parameter is set in which each head can have a different random layout.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which random layout is set
"""
num_blocks = layout.shape[1]
if (num_blocks < self.num_random_blocks):
raise ValueError(
f'Number of random blocks, {self.num_random_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
)
for row in range(0, num_blocks):
rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
layout[h, row, rnd_cols] = 1
return layout
def set_local_layout(self, h, layout):
"""Sets local attantion layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local layout is set
"""
num_blocks = layout.shape[1]
start_block_idx = 0
end_block_idx = 0
for block_size in self.local_window_blocks:
end_block_idx += block_size
end_block_idx = min(end_block_idx, num_blocks)
for row in range(start_block_idx, end_block_idx):
for col in range(
start_block_idx,
(row + 1 if self.attention == 'unidirectional' else end_block_idx)):
layout[h, row, col] = 1
start_block_idx += block_size
# if any part of the sequence is not yet covered, use the last local window block size for the remaining local windows
for i in range(start_block_idx, num_blocks, block_size):
end_block_idx = min(i + block_size, num_blocks)
for row in range(i, end_block_idx):
for col in range(
i,
(row + 1 if self.attention == 'unidirectional' else end_block_idx)):
layout[h, row, col] = 1
return layout
def set_global_layout(self, h, layout):
"""Sets global attantion layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
if (self.global_block_end_indices is None):
for idx in self.global_block_indices:
# if the global block idx is in the range of the sequence blocks
if (idx < num_blocks):
#global rows
if (self.horizontal_global_attention):
layout[h, idx, :] = 1
#global columns
first_row = 0 if self.attention == 'bidirectional' else idx
layout[h, first_row:, idx] = 1
else:
for _, (start_idx, end_idx) in enumerate(zip(self.global_block_indices, self.global_block_end_indices)):
# if the global block idx is in the range of the sequence blocks
if (start_idx < num_blocks):
end_idx = min(end_idx, num_blocks)
#global rows
if (self.horizontal_global_attention):
layout[h, start_idx:end_idx, :] = 1
#global columns
first_row = 0 if self.attention == 'bidirectional' else start_idx
layout[h, first_row:, start_idx:end_idx] = 1
return layout
def make_layout(self, seq_len):
"""Generates `Variable` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining the underlying sequence length; must be divisible by the block size
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `Variable` sparsity layout of all heads
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_random_layout(h, layout)
layout = self.set_local_layout(h, layout)
layout = self.set_global_layout(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
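# Illustrative usage sketch (not part of the original source); the argument values
# below are assumptions chosen for the example.
#   variable = VariableSparsityConfig(num_heads=16,
#                                     num_random_blocks=1,
#                                     local_window_blocks=[2, 4],
#                                     global_block_indices=[0],
#                                     attention='bidirectional')
#   layout = variable.make_layout(seq_len=1024)
#   # the first local window spans 2 blocks, all following windows span 4 blocks;
#   # each block row also attends to 1 random block and to global block 0.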
class BigBirdSparsityConfig(SparsityConfig):
"""Configuration class to store `BigBird` sparsity configuration.
For more details about this sparsity config, please see `Big Bird: Transformers for Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf
This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity.
"""
def __init__(self,
num_heads,
block=16,
different_layout_per_head=False,
num_random_blocks=1,
num_sliding_window_blocks=3,
num_global_blocks=1):
"""Initialize the BigBird Sparsity Pattern Config.
For a usage example, please see the DeepSpeed Sparse Transformer tutorial (TODO).
Arguments:
num_heads: required: an integer determining the number of attention heads of the layer.
block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, in which this parameter defines the size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
num_random_blocks: optional: an integer determining the number of random blocks in each block row.
num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding local attention window.
num_global_blocks: optional: an integer determining how many consecutive blocks, starting from index 0, are considered as global attention. Global block tokens will be attended by all other block tokens and will attend to all other block tokens as well.
"""
super().__init__(num_heads, block, different_layout_per_head)
self.num_random_blocks = num_random_blocks
self.num_sliding_window_blocks = num_sliding_window_blocks
self.num_global_blocks = num_global_blocks
def set_random_layout(self, h, layout):
"""Sets random attantion layout used by the given head in the sparse attention.
Note) By default, it assumes there will be a unique random block layout for all heads; unless `different_layout_per_head` parameter is set in which each head can have a different random layout.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which random layout is set
"""
num_blocks = layout.shape[1]
if (num_blocks < self.num_random_blocks):
raise ValueError(
f'Number of random blocks, {self.num_random_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
)
for row in range(0, num_blocks):
rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
layout[h, row, rnd_cols] = 1
return layout
def set_sliding_window_layout(self, h, layout):
"""Sets sliding local attantion layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local sliding window layout is set
"""
num_blocks = layout.shape[1]
if (num_blocks < self.num_sliding_window_blocks):
raise ValueError(
f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
)
w = self.num_sliding_window_blocks // 2
for row in range(0, num_blocks):
start = max(0, row - w)
end = min(row + w + 1, num_blocks)
layout[h, row, start:end] = 1
return layout
def set_global_layout_itc(self, h, layout):
"""Sets global attantion layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
if (num_blocks < self.num_global_blocks):
raise ValueError(
f'Number of global blocks, {self.num_global_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
)
#global rows
layout[h, 0:self.num_global_blocks, :] = 1
#global columns
layout[h, :, 0:self.num_global_blocks] = 1
return layout
def make_layout(self, seq_len):
"""Generates `BigBird` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining the underlying sequence length; must be divisible by the block size
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `BigBird` sparsity layout of all heads
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_random_layout(h, layout)
layout = self.set_sliding_window_layout(h, layout)
layout = self.set_global_layout_itc(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
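# Illustrative usage sketch (not part of the original source); the argument values
# below are assumptions chosen for the example.
#   bigbird = BigBirdSparsityConfig(num_heads=16,
#                                   num_random_blocks=2,
#                                   num_sliding_window_blocks=3,
#                                   num_global_blocks=1)
#   layout = bigbird.make_layout(seq_len=1024)
#   # each block row attends to a 3-block sliding window, 2 random blocks, and the
#   # first (global) block; the global block also attends to every block row.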
class BSLongformerSparsityConfig(SparsityConfig):
"""Configuration class to store edited `Longformer` sparsity configuration.
Note) this is a block-sparse version of the Longformer, which differs slightly from the original Longformer, whose sparsity is element-wise.
For more details about this sparsity config, please see `Longformer: The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf
This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity.
"""
def __init__(self,
num_heads,
block=16,
different_layout_per_head=False,
num_sliding_window_blocks=3,
global_block_indices=[0],
global_block_end_indices=None):
"""Initialize the edited `Longformer` Sparsity Pattern Config.
For a usage example, please see the DeepSpeed Sparse Transformer tutorial (TODO).
Arguments:
num_heads: required: an integer determining the number of attention heads of the layer.
block: optional: an integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, in which this parameter defines the size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned a different sparsity layout; default is false and this will be satisfied based on availability.
num_sliding_window_blocks: optional: an integer determining the number of blocks in the sliding local attention window.
global_block_indices: optional: a list of integers determining which blocks are considered as global attention. The given indices determine the blocks that all other token blocks attend to, and these blocks attend to all other token blocks. The default value is index 0 only. Notice that if the global_block_end_indices parameter is set, this parameter is used as the starting index of each global window.
global_block_end_indices: optional: a list of integers determining the end indices of the global window blocks. By default this is not used. But if it is set, it must have the same size as the global_block_indices parameter; combining these two parameters, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered as global attention.
"""
super().__init__(num_heads, block, different_layout_per_head)
self.num_sliding_window_blocks = num_sliding_window_blocks
self.global_block_indices = global_block_indices
if (global_block_end_indices is not None):
if (len(global_block_indices) != len(global_block_end_indices)):
raise ValueError(
f'Global block start indices length, {len(global_block_indices)}, must be same as global block end indices length, {len(global_block_end_indices)}!'
)
for _, (start_idx, end_idx) in enumerate(zip(global_block_indices, global_block_end_indices)):
if start_idx >= end_idx:
raise ValueError(
f'Global block start index, {start_idx}, must be smaller than global block end index, {end_idx}!'
)
self.global_block_end_indices = global_block_end_indices
def set_sliding_window_layout(self, h, layout):
"""Sets sliding local attantion layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which local sliding window layout is set
"""
num_blocks = layout.shape[1]
if (num_blocks < self.num_sliding_window_blocks):
raise ValueError(
f'Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than the overall number of blocks in a row, {num_blocks}!'
)
w = self.num_sliding_window_blocks // 2
for row in range(0, num_blocks):
start = max(0, row - w)
end = min(row + w + 1, num_blocks)
layout[h, row, start:end] = 1
return layout
def set_global_layout(self, h, layout):
"""Sets global attantion layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head; may not be completly set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
if (self.global_block_end_indices is None):
for idx in self.global_block_indices:
# if the global block idx is in the range of the sequence blocks
if (idx < num_blocks):
#global rows
layout[h, idx, :] = 1
#global columns
layout[h, :, idx] = 1
else:
for _, (start_idx, end_idx) in enumerate(zip(self.global_block_indices, self.global_block_end_indices)):
# if the global block idx is in the range of the sequence blocks
if (start_idx < num_blocks):
end_idx = min(end_idx, num_blocks)
#global rows
layout[h, start_idx:end_idx, :] = 1
#global columns
layout[h, :, start_idx:end_idx] = 1
return layout
def make_layout(self, seq_len):
"""Generates edited `Longformer` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining the underlying sequence length; must be divisible by the block size
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the `BSLongformer` sparsity layout of all heads
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_sliding_window_layout(h, layout)
layout = self.set_global_layout(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
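# Illustrative usage sketch (not part of the original source); the argument values
# below are assumptions chosen for the example.
#   longformer = BSLongformerSparsityConfig(num_heads=16,
#                                           num_sliding_window_blocks=3,
#                                           global_block_indices=[0])
#   layout = longformer.make_layout(seq_len=1024)
#   # each block row attends to a 3-block sliding window; block 0 is global, so its
#   # row and column are fully set.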
import sys
import os
def _build_file_index(directory, suffix='.tr'):
"""Build an index of source files and their basenames in a given directory.
Args:
directory (string): the directory to index
suffix (string): index files with this suffix
Returns:
list: A list of tuples of the form [(basename, absolute path), ...]
"""
index = []
for fname in os.listdir(directory):
if fname.endswith(suffix):
basename = fname[:fname.rfind(suffix)] # strip the suffix
path = os.path.join(directory, fname)
index.append((basename, path))
return index
# Go over all local source files and parse them as strings
_module = sys.modules[_build_file_index.__module__]
_directory = os.path.dirname(os.path.realpath(__file__))
for name, fname in _build_file_index(_directory):
with open(fname, 'r') as fin:
setattr(_module, name, fin.read())
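# Illustrative note (not part of the original source): after the loop above runs,
# every `.tr` file in this directory is exposed as a module-level string attribute
# named after its basename. For example, assuming a sibling file `matmul.tr`, this
# module would gain a string attribute `matmul` holding that kernel's source text.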
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
TYPE* B __readonly __noalias __aligned(16),
TYPE* C __noalias __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
long stride_za __multipleof(8),
long stride_zb __multipleof(8),
long stride_zc __multipleof(8),
long stride_ha __multipleof(8),
long stride_hb __multipleof(8),
long stride_hc __multipleof(8),
int DS0, int DS1,
int SDD_K __multipleof(16),
int SDD_off_width,
int* lut, int* locks, int nlocks) {
/* ---------------- */
/* Prologue */
/* ---------------- */
// program ids
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pidz = get_program_id(2);
#ifdef SDD
// load LUT header
pid1 = pid1 + SDD_off_width;
int blockidm[TM] = (0 ... TM) / BLOCK;
int blockidn[TN] = (0 ... TN) / BLOCK;
int offlutm[TM] = blockidm*(TN/BLOCK)*4;
int offlutn[TN] = blockidn*4;
int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
int z = *(header + 0);
int i[TM] = *(header + 1 + offlutm);
int j[TN] = *(header + 2 + offlutn);
int AS1 = SDD_K / TZ;
int lockid = select(TZ > 1, 1, 0);
int offka = pid0 * AS1;
int offkb = pid0 * AS1;
int offmc = 0;
int offnc = 0;
int offpa = 0;
int offpb = 0;
int maxid = TZ;
int offhc = 0;
int offha = z;
int offhb = z;
int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
#else
// load LUT header
int *header = lut + pid0 * 6;
int offset = *(header + 0);
int AS1 = *(header + 1);
int column = *(header + 2);
int depth = *(header + 3);
int lockid = *(header + 4);
int maxid = *(header + 5);
int *pinc = lut + offset;
int offhc = depth;
#ifdef DSD
// output offset
int offnc = pid1 * TN;
int offmc = column * TM;
int offpc = 0;
// dense input offset
int offnb = pid1 * TN;
int offkb __multipleof(8) = *pinc;
int offpb = 0;
// sparse input offset
int offma = 0;
int offka = 0;
long offpa __multipleof(8) = *(pinc + 1);
offpa = offpa * BLOCK * BLOCK;
int offha = 0;
int offhb = depth;
#endif
#ifdef DDS
// output offset
int offmc = pid1 * TM;
int offnc = column * TN;
int offpc = 0;
// dense input offset
int offma = pid1 * TM;
int offka __multipleof(8) = *pinc;
int offpa = 0;
// sparse input offset
int offnb = 0;
int offkb = 0;
long offpb __multipleof(8) = *(pinc + 1);
offpb = offpb * BLOCK * BLOCK;
int offha = depth;
int offhb = 0;
#endif
int ram[TM] = offma + 0 ... TM;
int rbn[TN] = offnb + 0 ... TN;
#endif
// initialize a, b pointers
int rka[TK] = offka + 0 ... TK;
int rkb[TK] = offkb + 0 ... TK;
TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
// pre-fetch
#ifdef DDS
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
#else
bool checkbn[TK, TN] = AS1 > 0;
#endif
TYPE a[TM, TK] = checkam ? *pa : 0;
TYPE b[TK, TN] = checkbn ? *pb : 0;
/* ---------------- */
/* Inner Loop */
/* ---------------- */
// create result tile
float acc[TM, TN] = 0;
int step = TK;
for(int k = AS1; k > 0; k -= step) {
acc += a @ b;
// update pointers
#ifdef SDD
int inc_a = TK * STRIDE_AK;
int inc_b = TK * STRIDE_BK;
#else
pinc += 2;
#ifdef DSD
int inc_b __multipleof(8) = *pinc;
int inc_a __multipleof(8) = *(pinc + 1);
inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
int inc_a __multipleof(8) = *pinc;
int inc_b __multipleof(8) = *(pinc + 1);
inc_a = inc_a * STRIDE_AK;
#endif
#endif
pa += inc_a;
pb += inc_b;
// pre-fetch
bool checkak[TM, TK] = k > TK;
bool checkbk[TK, TN] = k > TK;
bool checka[TM, TK] = checkam && checkak;
bool checkb[TK, TN] = checkbk && checkbn;
a = *?(checka)pa;
b = *?(checkb)pb;
}
TYPE c[TM, TN] = acc;
/* ---------------- */
/* Epilogue */
/* ---------------- */
// initialize c pointers
#ifdef SDD
bool checkc[TM, TN] = 1;
// rematerialize
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
int rr_offlutn[TN] = rr_blockidn*4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
int bkid[TM, TN] = *(header + off_bkid);
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
// range within blocks
int rcm[TM] = (0 ... TM) % BLOCK;
int rcn[TN] = (0 ... TN) % BLOCK;
#else
int rcm[TM] = offmc + 0 ... TM;
int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
#endif
#ifdef DDS
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
// write-back directly
if(lockid == 0) {
*?(checkc) pc = c;
}
// accumulate partial result using spin-locks
else {
int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % maxid);
atomic_xchg(plock, 0);
}
}
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
__global__ void softmax_bwd(TYPE * X __readonly __noalias __aligned(16),
float scale,
TYPE* DX __readonly __noalias __aligned(16),
int* LUT,
int sizemax,
long stride_zx __multipleof(BLOCK),
long stride_zdx __multipleof(BLOCK)) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
int* header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
// bounds checking on lut
bool check[TN] = rbn < size;
int rbmn[TN] = check ? rbn : size - 1;
// initialize pointers to block-sparse input
long blockid[TN] = *(LUT + offset + rbmn*4);
TYPE* px[TN] = X + pidz * stride_zx
+ blockid * BLOCK * BLOCK
+ rxm * BLOCK
+ rxn;
TYPE* pdx[TN] = DX + pidz * stride_zdx
+ blockid * BLOCK * BLOCK
+ rxm * BLOCK
+ rxn;
// compute fused softmax backward
TYPE x[TN] = check ? *px : 0;
TYPE dx[TN] = check ? *pdx : 0;
float Fdx[TN] = dx;
float Fx[TN] = x;
float Fxdx[TN] = Fdx*Fx;
float Fxdxsum = Fxdx[+];
float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
TYPE y[TN] = Fy;
// write-back
*? (check)pdx = y;
}
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
__global__ void softmax_fwd(TYPE *X __readonly __noalias __aligned(16),
float scale,
int *LUT __readonly __noalias __aligned(16),
TYPE *RPE __readonly __noalias __aligned(16),
TYPE *KP_M __readonly __noalias __aligned(16),
TYPE *ATTN_M __readonly __noalias __aligned(16),
int num_blocks,
int sizemax,
long stride_zx __multipleof(BLOCK),
long stride_zrpe __multipleof(BLOCK),
int stride_hrpe __multipleof(BLOCK),
int stride_srpe __multipleof(BLOCK),
int stride_zkpm __multipleof(BLOCK),
int stride_zattnm __multipleof(BLOCK)){
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
int* header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
bool check[TN] = rbn < size;
int rbmn[TN] = check ? rbn : size - 1;
// block id and column id
long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
long rowid [TN] = *(LUT + offset + rbmn*4 + 2);
long headid [TN] = *(LUT + offset + rbmn*4 + 3);
// pointers to X
TYPE* px[TN] = X + pidz * stride_zx
+ blockid * BLOCK * BLOCK
+ rxm * BLOCK
+ rxn;
#ifdef APPLY_RPE
// pointers to relative position embedding
TYPE* prpe[TN] = RPE + pidz * stride_zrpe
+ headid * stride_hrpe
+ columnid * BLOCK
+ rowid * BLOCK * stride_srpe
+ rxm * stride_srpe
+ rxn;
#endif
#ifdef APPLY_KP_MASK
// pointers to key padding mask
TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
+ columnid * BLOCK
+ rxn;
#endif
#ifdef APPLY_ATTN_MASK
// pointers to attention mask
TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
+ rowid * BLOCK * stride_zattnm
+ rxm * stride_zattnm
+ rxn;
#endif
// load input
TYPE x[TN] = check ? *px : -INFINITY;
#ifdef APPLY_RPE
// load relative position embedding
TYPE rpe[TN] = check ? *prpe : 0;
#endif
#ifdef APPLY_KP_MASK
// load key-padding mask
TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
#endif
#ifdef APPLY_ATTN_MASK
// load attention mask
TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
#endif
// compute softmax in float
#ifdef APPLY_RPE
float Frpe[TN] = rpe;
#endif
#ifdef APPLY_KP_MASK
float Fkp_m[TN] = kp_m;
#endif
#ifdef APPLY_ATTN_MASK
float Fattn_m[TN] = attn_m;
#endif
#ifdef KP_MASK_MUL
Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
#endif
#ifdef ATTN_MASK_MUL
Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
#endif
float Fx[TN] = x;
#ifdef APPLY_SCALE
Fx = Fx * scale; // apply scale
#endif
#ifdef APPLY_RPE
Fx = Fx + Frpe; // apply relative position embedding
#endif
#ifdef APPLY_KP_MASK
Fx = Fx + Fkp_m; // apply key padding mask
#endif
#ifdef APPLY_ATTN_MASK
Fx = Fx + Fattn_m; // apply attention mask
#endif
float Fxmax = Fx[max];
float Fy[TN] = exp(Fx - Fxmax);
float Fysum = (check ? Fy : 0)[+];
// write-back in half/float
TYPE y[TN] = Fy;
TYPE ysum = Fysum;
*?(check)px = y / ysum;
}
from deepspeed.ops.transformer.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from torch import nn
from torch.autograd import Function
import torch
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import json
import math
import deepspeed_transformer_cuda as ds_transformer_cuda
import deepspeed_stochastic_transformer_cuda as ds_stochastic_transformer_cuda
import importlib
import torch
from torch import nn
from torch.autograd import Function
# Cuda modules will be imported if needed
transformer_cuda_module = None
stochastic_transformer_cuda_module = None
class TransformerConfig():
......@@ -159,7 +165,7 @@ class DeepSpeedTransformerFunction(Function):
if bsz > config.batch_size:
raise ValueError('Input batch size exceeds the limit.')
cuda_module = ds_stochastic_transformer_cuda if config.stochastic_mode else ds_transformer_cuda
cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32
(output,
......@@ -321,7 +327,7 @@ class DeepSpeedTransformerFunction(Function):
norm_w,
norm_b) = ctx.saved_tensors
cuda_module = ds_stochastic_transformer_cuda if ctx.config.stochastic_mode else ds_transformer_cuda
cuda_module = stochastic_transformer_cuda_module if ctx.config.stochastic_mode else transformer_cuda_module
backward_func = cuda_module.backward_fp16 if ctx.config.fp16 else cuda_module.backward_fp32
(grad_input,
......@@ -457,8 +463,22 @@ class DeepSpeedTransformerLayer(nn.Module):
self.norm_w = initial_weights[7]
self.norm_b = initial_biases[7]
# Import cuda modules if needed
global transformer_cuda_module, stochastic_transformer_cuda_module
if transformer_cuda_module is None or stochastic_transformer_cuda_module is None:
try:
transformer_cuda_module = importlib.import_module(
"deepspeed.ops.transformer.transformer_cuda")
stochastic_transformer_cuda_module = importlib.import_module(
"deepspeed.ops.transformer.stochastic_transformer_cuda")
except ImportError as err:
print(
"Unable to import transformer cuda extension, please build DeepSpeed with cuda/cpp extensions."
)
raise err
# create the layer in cuda kernels.
cuda_module = ds_stochastic_transformer_cuda if self.config.stochastic_mode else ds_transformer_cuda
cuda_module = stochastic_transformer_cuda_module if self.config.stochastic_mode else transformer_cuda_module
create_layer_func = cuda_module.create_transformer_layer_fp16 if self.config.fp16 else cuda_module.create_transformer_layer_fp32
create_layer_func(self.config.layer_id,
......
......@@ -13,16 +13,17 @@ b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
import copy
import torch.distributed as dist
import torch
import contextlib
import torch.distributed as dist
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from deepspeed.pt.deepspeed_timer import SynchronizedWallClockTimer as Timers
import torch.distributed as dist
from deepspeed.pt.deepspeed_config import DeepSpeedConfig
from deepspeed.pt.log_utils import logger
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers
#DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False
......
......@@ -3,7 +3,7 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
from deepspeed.runtime.config_utils import get_scalar_param
#########################################
# DeepSpeed Activation Checkpointing
......
......@@ -6,12 +6,12 @@ Licensed under the MIT license.
import torch
import json
import copy
from deepspeed.pt.deepspeed_constants import *
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.pt.deepspeed_config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from deepspeed.pt.deepspeed_zero_config import DeepSpeedZeroConfig
from deepspeed.pt.deepspeed_checkpointing_config import DeepSpeedActivationCheckpointingConfig
from deepspeed.pt.log_utils import logger
from deepspeed.runtime.constants import *
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.runtime.config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
from deepspeed.utils import logger
TENSOR_CORE_ALIGN_SIZE = 8
ADAM_OPTIMIZER = 'adam'
......@@ -158,6 +158,177 @@ def get_gradient_clipping(param_dict):
return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT)
def get_sparse_attention(param_dict):
if SPARSE_ATTENTION in param_dict.keys():
sparsity = param_dict[SPARSE_ATTENTION]
mode = get_sparse_attention_mode(sparsity)
if (mode == SPARSE_DENSE_MODE):
return get_sparse_dense_config(sparsity)
elif (mode == SPARSE_FIXED_MODE):
return get_sparse_fixed_config(sparsity)
elif (mode == SPARSE_VARIABLE_MODE):
return get_sparse_variable_config(sparsity)
elif (mode == SPARSE_BIGBIRD_MODE):
return get_sparse_bigbird_config(sparsity)
elif (mode == SPARSE_BSLONGFORMER_MODE):
return get_sparse_bslongformer_config(sparsity)
else:
raise NotImplementedError(
f'Given sparsity mode, {mode}, has not been implemented yet!')
else:
return None
def get_sparse_dense_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
return {SPARSE_MODE: SPARSE_DENSE_MODE, SPARSE_BLOCK: block}
def get_sparse_fixed_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
different_layout_per_head = get_scalar_param(
sparsity,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
num_local_blocks = get_scalar_param(sparsity,
SPARSE_NUM_LOCAL_BLOCKS,
SPARSE_NUM_LOCAL_BLOCKS_DEFAULT)
num_global_blocks = get_scalar_param(sparsity,
SPARSE_NUM_GLOBAL_BLOCKS,
SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
attention = get_scalar_param(sparsity,
SPARSE_ATTENTION_TYPE,
SPARSE_ATTENTION_TYPE_DEFAULT)
horizontal_global_attention = get_scalar_param(
sparsity,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT)
num_different_global_patterns = get_scalar_param(
sparsity,
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS,
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT)
return {
SPARSE_MODE: SPARSE_FIXED_MODE,
SPARSE_BLOCK: block,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
SPARSE_NUM_LOCAL_BLOCKS: num_local_blocks,
SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks,
SPARSE_ATTENTION_TYPE: attention,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention,
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS: num_different_global_patterns
}
def get_sparse_variable_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
different_layout_per_head = get_scalar_param(
sparsity,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
num_random_blocks = get_scalar_param(sparsity,
SPARSE_NUM_RANDOM_BLOCKS,
SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
local_window_blocks = get_scalar_param(sparsity,
SPARSE_LOCAL_WINDOW_BLOCKS,
SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT)
global_block_indices = get_scalar_param(sparsity,
SPARSE_GLOBAL_BLOCK_INDICES,
SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
global_block_end_indices = get_scalar_param(sparsity,
SPARSE_GLOBAL_BLOCK_END_INDICES,
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT)
attention = get_scalar_param(sparsity,
SPARSE_ATTENTION_TYPE,
SPARSE_ATTENTION_TYPE_DEFAULT)
horizontal_global_attention = get_scalar_param(
sparsity,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT)
return {
SPARSE_MODE: SPARSE_VARIABLE_MODE,
SPARSE_BLOCK: block,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
SPARSE_LOCAL_WINDOW_BLOCKS: local_window_blocks,
SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices,
SPARSE_ATTENTION_TYPE: attention,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention
}
def get_sparse_bigbird_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
different_layout_per_head = get_scalar_param(
sparsity,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
num_random_blocks = get_scalar_param(sparsity,
SPARSE_NUM_RANDOM_BLOCKS,
SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
num_sliding_window_blocks = get_scalar_param(
sparsity,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT)
num_global_blocks = get_scalar_param(sparsity,
SPARSE_NUM_GLOBAL_BLOCKS,
SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
return {
SPARSE_MODE: SPARSE_BIGBIRD_MODE,
SPARSE_BLOCK: block,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks
}
def get_sparse_bslongformer_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
different_layout_per_head = get_scalar_param(
sparsity,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
num_sliding_window_blocks = get_scalar_param(
sparsity,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT)
global_block_indices = get_scalar_param(sparsity,
SPARSE_GLOBAL_BLOCK_INDICES,
SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
global_block_end_indices = get_scalar_param(sparsity,
SPARSE_GLOBAL_BLOCK_END_INDICES,
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT)
return {
SPARSE_MODE: SPARSE_BSLONGFORMER_MODE,
SPARSE_BLOCK: block,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices
}
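# Illustrative sketch (not part of the original source): a "sparse_attention"
# section of the DeepSpeed config JSON that get_sparse_attention() above would
# parse into a fixed-mode sparsity dict. The values shown are example choices,
# not mandated defaults.
#   "sparse_attention": {
#     "mode": "fixed",
#     "block": 16,
#     "different_layout_per_head": true,
#     "num_local_blocks": 4,
#     "num_global_blocks": 1,
#     "attention": "bidirectional",
#     "horizontal_global_attention": false
#   }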
def get_sparse_attention_mode(param_dict):
if SPARSE_MODE in param_dict.keys():
return param_dict[SPARSE_MODE]
else:
return SPARSE_MODE_DEFAULT
def get_sparse_attention_type(param_dict):
if SPARSE_ATTENTION_TYPE in param_dict.keys():
return param_dict[SPARSE_ATTENTION_TYPE]
else:
return SPARSE_ATTENTION_TYPE_DEFAULT
def get_optimizer_name(param_dict):
if OPTIMIZER in param_dict.keys() and \
TYPE in param_dict[OPTIMIZER].keys():
......@@ -358,6 +529,8 @@ class DeepSpeedConfig(object):
self.tensorboard_output_path = get_tensorboard_output_path(param_dict)
self.tensorboard_job_name = get_tensorboard_job_name(param_dict)
self.sparse_attention = get_sparse_attention(param_dict)
def _batch_assertion(self):
train_batch = self.train_batch_size
......
......@@ -17,6 +17,42 @@ ROUTE_ENCODE = "encode"
TRAIN_BATCH_SIZE = "train_batch_size"
TRAIN_BATCH_SIZE_DEFAULT = None
#############################################
# Sparse attention
#############################################
SPARSE_ATTENTION = "sparse_attention"
SPARSE_DENSE_MODE = "dense"
SPARSE_FIXED_MODE = "fixed"
SPARSE_VARIABLE_MODE = "variable"
SPARSE_BIGBIRD_MODE = "bigbird"
SPARSE_BSLONGFORMER_MODE = "bslongformer"
SPARSE_MODE = "mode"
SPARSE_MODE_DEFAULT = SPARSE_FIXED_MODE
SPARSE_BLOCK = "block"
SPARSE_BLOCK_DEFAULT = 16
SPARSE_DIFFERENT_LAYOUT_PER_HEAD = "different_layout_per_head"
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT = False
SPARSE_NUM_LOCAL_BLOCKS = "num_local_blocks"
SPARSE_NUM_LOCAL_BLOCKS_DEFAULT = 4
SPARSE_NUM_GLOBAL_BLOCKS = "num_global_blocks"
SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT = 1
SPARSE_ATTENTION_TYPE = "attention"
SPARSE_ATTENTION_TYPE_DEFAULT = "bidirectional"
SPARSE_HORIZONTAL_GLOBAL_ATTENTION = "horizontal_global_attention"
SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT = False
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS = "num_differnt_global_patterns"
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT = 1
SPARSE_NUM_RANDOM_BLOCKS = "num_random_blocks"
SPARSE_NUM_RANDOM_BLOCKS_DEFAULT = 0
SPARSE_LOCAL_WINDOW_BLOCKS = "local_window_blocks"
SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT = [4]
SPARSE_GLOBAL_BLOCK_INDICES = "global_block_indices"
SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT = [0]
SPARSE_GLOBAL_BLOCK_END_INDICES = "global_block_end_indices"
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT = None
SPARSE_NUM_SLIDING_WINDOW_BLOCKS = "num_sliding_window_blocks"
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT = 3
#############################################
# Optimizer and lr scheduler
#############################################
......
......@@ -2,36 +2,35 @@
Copyright 2019 The Microsoft DeepSpeed Team
'''
import torch
import os
import torch
import warnings
import torch.distributed as dist
from apex import amp
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from apex import amp
from tensorboardX import SummaryWriter
from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.pt.log_utils import logger
import deepspeed.pt.deepspeed_checkpointing as deepspeed_activation_checkpointing
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
from deepspeed.pt.deepspeed_constants import \
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
TORCH_DISTRIBUTED_DEFAULT_PORT, \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.ops.lamb import FusedLamb
import deepspeed.pt.deepspeed_lr_schedules as lr_schedules
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor
from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
SUMMARY_WRITER_DIR_NAME = "JobId"
......@@ -92,7 +91,7 @@ def print_configuration(args, name):
logger.info(' {} {} {}'.format(arg, dots, getattr(args, arg)))
class DeepSpeedLight(Module):
class DeepSpeedEngine(Module):
r"""DeepSpeed engine for training.
"""
def __init__(self,
......@@ -106,7 +105,7 @@ class DeepSpeedLight(Module):
dist_init_required=None,
collate_fn=None,
config_params=None):
super(DeepSpeedLight, self).__init__()
super(DeepSpeedEngine, self).__init__()
self.client_optimizer = optimizer
self.client_model_parameters = model_parameters
......
......@@ -9,9 +9,9 @@ import torch
import math
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.pt.log_utils import logger
from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger
class FP16_Optimizer(object):
......