Unverified Commit e5bbc2e5 authored by Jeff Rasley, committed by GitHub

Sparse attn + ops/runtime refactor + v0.3.0 (#343)



* Sparse attn + ops/runtime refactor + v0.3.0
Co-authored-by: Arash Ashari <arashari@microsoft.com>
parent 838f53b7
...@@ -20,7 +20,7 @@ AllowShortLoopsOnASingleLine: true ...@@ -20,7 +20,7 @@ AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes AlwaysBreakTemplateDeclarations: true
BinPackArguments: false BinPackArguments: false
BinPackParameters: false BinPackParameters: false
BraceWrapping: BraceWrapping:
......
...@@ -8,7 +8,7 @@ deepspeed/git_version_info.py ...@@ -8,7 +8,7 @@ deepspeed/git_version_info.py
# Build + installation data # Build + installation data
build/ build/
dist/ dist/
fused_lamb_*.so *.so
deepspeed.egg-info/ deepspeed.egg-info/
# Website # Website
......
...@@ -8,7 +8,8 @@ RUN apt-get update && \ ...@@ -8,7 +8,8 @@ RUN apt-get update && \
software-properties-common \ software-properties-common \
openssh-client openssh-server \ openssh-client openssh-server \
pdsh curl sudo net-tools \ pdsh curl sudo net-tools \
vim iputils-ping wget vim iputils-ping wget \
llvm-9-dev cmake
############################################################################## ##############################################################################
# Installation Latest Git # Installation Latest Git
...@@ -85,7 +86,7 @@ RUN mkdir -p ${STAGE_DIR} && \ ...@@ -85,7 +86,7 @@ RUN mkdir -p ${STAGE_DIR} && \
dpkg -i ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_TAG}_all.deb dpkg -i ${STAGE_DIR}/nvidia-peer-memory_${NV_PEER_MEM_TAG}_all.deb
############################################################################## ##############################################################################
## Ucomment and set SSH Daemon port ## SSH daemon port inside container cannot conflict with host OS port
############################################################################### ###############################################################################
ENV SSH_PORT=2222 ENV SSH_PORT=2222
RUN cat /etc/ssh/sshd_config > ${STAGE_DIR}/sshd_config && \ RUN cat /etc/ssh/sshd_config > ${STAGE_DIR}/sshd_config && \
......
jobs: jobs:
- job: Default - job: DeepSpeed_Tests
timeoutInMinutes: 360 timeoutInMinutes: 360
pool: pool:
name: 'GPU_testing' name: 'DS_testing'
strategy: strategy:
matrix: matrix:
Python36: PyTorch12-CUDA100:
python.version: '3.6' python.version: '3.6'
#Python35: cuda.version: '10.0'
# python.version: '3.5' pytorch.version: '1.2'
#Python37: torchvision.version: '0.4.0'
runmodeltests: true
#PyTorch15-CUDA101:
# python.version: '3.7' # python.version: '3.7'
#Python38: # cuda.version: '10.1'
# python.version: '3.8' # pytorch.version: '1.5.0+cu101'
# torchvision.version: '0.6.0+cu101'
# runmodeltests: true
##PyTorch15-CUDA102:
# python.version: '3.7'
# cuda.version: '10.2'
# pytorch.version: '1.5'
# torchvision.version: '0.6.1'
# runmodeltests: true
variables:
conda_env: 'ds_test_py$(python.version)_cuda$(cuda.version)_pytorch$(pytorch.version)'
steps: steps:
- task: UsePythonVersion@0 # Unfortunately nvidia's nvcc_linux-64=<version> seems to install 10.1 regardless?
inputs: # Most of this complexity is a workaround to get the compiler toolchain to match the
versionSpec: '$(python.version)' # cudatoolkit runtime
addToPath: true - script: |
architecture: 'x64' conda create --force --yes -n $(conda_env) python=$(python.version) cudatoolkit=$(cuda.version)
displayName: 'Use Python $(python.version)' source activate $(conda_env)
conda install -q --yes conda
conda install -q --yes pip
conda install -q --yes gxx_linux-64
if [[ $(cuda.version) != "10.2" ]]; then conda install --yes -c conda-forge cudatoolkit-dev=$(cuda.version) ; fi
displayName: 'Setup environment python=$(python.version) pytorch=$(pytorch.version) cuda=$(cuda.version)'
# Manually install torch/torchvision first to enforce versioning.
- script: | - script: |
python -m pip install --upgrade pip source activate $(conda_env)
pip install --user -r requirements.txt pip install --progress-bar=off torch==$(pytorch.version) torchvision==$(torchvision.version)
./install.sh --pip_sudo #-f https://download.pytorch.org/whl/torch_stable.html
displayName: 'Install dependencies' ./install.sh --local_only
#python -I basic_install_test.py
displayName: 'Install DeepSpeed'
- script: | - script: |
pre-commit run --all-files source activate $(conda_env)
displayName: 'Formatting checks' which python
python --version
which nvcc
nvcc --version
which deepspeed
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
python -c "import deepspeed; print('deepspeed:', deepspeed.__version__)"
displayName: 'Show environment'
- script: | - script: |
pytest --forked --verbose tests/unit/ source activate $(conda_env)
pytest --durations=0 --forked --verbose -x tests/unit/
displayName: 'Unit tests' displayName: 'Unit tests'
- script: | - script: |
source activate $(conda_env)
ln -s /data/Megatron-LM/data DeepSpeedExamples/Megatron-LM/ ln -s /data/Megatron-LM/data DeepSpeedExamples/Megatron-LM/
pip install --user -r DeepSpeedExamples/Megatron-LM/requirements.txt pip install --progress-bar=off -r DeepSpeedExamples/Megatron-LM/requirements.txt
cd tests/model/ cd tests/model/
pytest -s run_sanity_check.py rm -rf BingBertSquad/baseline
rm -rf Megatron_GPT2/baseline
pytest --durations=0 -s run_sanity_check.py
condition: and(succeeded(), eq(variables['runmodeltests'], true))
displayName: 'Model tests' displayName: 'Model tests'
#BingBertSquad logs #BingBertSquad logs
...@@ -52,35 +86,29 @@ jobs: ...@@ -52,35 +86,29 @@ jobs:
targetPath: '$(Build.SourcesDirectory)/tests/model/BingBertSquad/test/' targetPath: '$(Build.SourcesDirectory)/tests/model/BingBertSquad/test/'
artifactName: BingBertSquad_logs artifactName: BingBertSquad_logs
displayName: 'BingBertSquad log uploads' displayName: 'BingBertSquad log uploads'
condition: always() condition: eq(variables['runmodeltests'], true)
# Megatron test logs
#- task: PublishPipelineArtifact@1
# inputs:
# targetPath: '$(Build.SourcesDirectory)/tests/model/Megatron_GPT2/test/'
# artifactName: Megatron_GPT2_logs
# displayName: 'Megatron GPT2 log uploads'
# condition: always()
#- task: PublishPipelineArtifact@1 - job: Code_Quality_Checks
# inputs: pool:
# targetPath: '$(Build.SourcesDirectory)/tests/model/Megatron_GPT2/checkpoint_test_logs/' name: 'DS_testing'
# artifactName: Megatron_GPT2_checkpoint_logs variables:
# displayName: 'Megatron GPT2 checkpoint log uploads' conda_env: 'ds_codetest'
# condition: always()
steps:
- script: |
conda create --force --yes -n $(conda_env) python=3.7
source activate $(conda_env)
displayName: 'Create code test environment'
#BingBert logs - script: |
#- task: PublishPipelineArtifact@1 source activate $(conda_env)
# inputs: pip install pre-commit
# targetPath: '$(Build.SourcesDirectory)/tests/model/bing_bert/pretrain_test/' pre-commit run --all-files
# artifactName: BingBert_pretrain_logs displayName: 'Formatting checks'
# displayName: 'BingBert pretrain logs'
# condition: always()
#- task: PublishPipelineArtifact@1 - script: |
# inputs: source activate $(conda_env)
# targetPath: '$(Build.SourcesDirectory)/tests/model/bing_bert/checkpoint_test_logs/' pip install pylint
# artifactName: BingBert_checkpoint_logs pylint --exit-zero deepspeed/
# displayName: 'BingBert checkpoint logs' displayName: 'Code linter'
# condition: always()
...@@ -18,7 +18,7 @@ except Exception as err: ...@@ -18,7 +18,7 @@ except Exception as err:
raise err raise err
try: try:
fused_lamb = importlib.import_module('deepspeed_lamb_cuda') fused_lamb = importlib.import_module('deepspeed.ops.lamb.fused_lamb_cuda')
print('deepspeed fused lamb kernels successfully installed') print('deepspeed fused lamb kernels successfully installed')
except Exception as err: except Exception as err:
raise err raise err
...@@ -30,7 +30,8 @@ except ImportError: ...@@ -30,7 +30,8 @@ except ImportError:
print("using new-style apex") print("using new-style apex")
try: try:
ds_transformer = importlib.import_module('deepspeed_transformer_cuda') ds_transformer = importlib.import_module(
'deepspeed.ops.transformer.transformer_cuda')
print('deepspeed transformer kernels successfully installed') print('deepspeed transformer kernels successfully installed')
except Exception as err: except Exception as err:
raise err raise err
#!/usr/bin/env python #!/usr/bin/env python
from deepspeed.pt.deepspeed_run import main from deepspeed.launcher.runner import main
if __name__ == '__main__': if __name__ == '__main__':
main() main()
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
typedef std::vector<std::tuple<int, torch::Tensor>> ret_t;
void segment_blocks(torch::Tensor layout,
torch::Tensor idx,
torch::Tensor scratch,
int max_width,
ret_t& ret)
{
size_t H = layout.size(0);
size_t M = layout.size(1);
size_t N = layout.size(2);
torch::Tensor tmp = torch::zeros_like(layout);
auto _tmp = tmp.accessor<int, 3>();
auto _layout = layout.accessor<int, 3>();
auto _idx = idx.accessor<int, 3>();
auto _scratch = scratch.accessor<int, 3>();
std::vector<int> current(H, 0);
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (size_t h = 0; h < H; h++) {
// surrounding indices
std::vector<int> ii_left(max_width, -1);
std::vector<std::vector<int>> ii_top(max_width, std::vector<int>(N, -1));
for (size_t m = 0; m < M; m++) {
for (size_t n = 0; n < N; n++) {
int v = _layout[h][m][n];
if (v == 0) continue;
int n_left = ii_left[max_width - 1];
int m_top = ii_top[max_width - 1][n];
int top = (m_top >= 0) ? _tmp[h][m_top][n] : 0;
int left = (n_left >= 0) ? _tmp[h][m][n_left] : 0;
int topleft = (m_top >= 0 && n_left >= 0) ? _tmp[h][m_top][n_left] : 0;
int width = std::min(left, std::min(top, topleft)) + 1;
// reset width if blocks cannot be
// packed together (i.e., there's a 1 "in the middle")
for (int nn = n_left + 1; nn < n; nn++)
if (ii_top[max_width - 1][nn] > ii_top[max_width - 1][n]) width = 1;
_tmp[h][m][n] = width;
// update n_left ring buffer
for (int k = 0; k < max_width - 1; k++) ii_left[k] = ii_left[k + 1];
ii_left[max_width - 1] = n;
// update ii_top ring buffer
for (int k = 0; k < max_width - 1; k++) ii_top[k][n] = ii_top[k + 1][n];
ii_top[max_width - 1][n] = m;
// block is too small -- skip
if (width != max_width) continue;
// retained blocks are set to zeros
for (size_t km = 0; km < max_width; km++)
for (size_t kn = 0; kn < max_width; kn++) {
int mm = ii_top[km][n];
int nn = ii_left[kn];
if (mm < 0 || nn < 0) continue;
_layout[h][mm][nn] = 0;
_tmp[h][mm][nn] = 0;
_scratch[h][current[h]][0] = (int)h;
_scratch[h][current[h]][1] = (int)mm;
_scratch[h][current[h]][2] = (int)nn;
_scratch[h][current[h]][3] = _idx[h][mm][nn];
current[h]++;
}
}
}
}
std::vector<torch::Tensor> to_cat;
for (size_t h = 0; h < H; h++)
if (current[h] > 0) to_cat.push_back(scratch[h].slice(0, 0, current[h]));
if (!to_cat.empty()) ret.push_back({max_width, torch::cat(to_cat)});
}
ret_t sdd_segment(torch::Tensor layout, int start_width)
{
ret_t ret;
// block index
torch::Tensor idx = torch::zeros_like(layout);
int current = 0;
size_t H = layout.size(0);
size_t M = layout.size(1);
size_t N = layout.size(2);
auto _layout = layout.accessor<int, 3>();
auto _idx = idx.accessor<int, 3>();
for (size_t h = 0; h < H; h++)
for (size_t m = 0; m < M; m++)
for (size_t n = 0; n < N; n++) {
if (_layout[h][m][n] == 0) continue;
_idx[h][m][n] = current++;
}
// scratch memory
torch::Tensor scratch = torch::empty({H, layout.sum().item<int>(), 4}, layout.dtype());
for (int max_width = start_width; max_width > 0; max_width /= 2)
segment_blocks(layout, idx, scratch, max_width, ret);
return ret;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}
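For orientation, here is a hedged sketch of exercising this extension from Python by JIT-compiling the file above with `torch.utils.cpp_extension.load` and calling `sdd_segment` on a tiny layout. The source path and extension name are illustrative assumptions, not part of this commit.

```python
# Hypothetical usage sketch for the block segmentation helper above; the source path
# and module name are assumptions.
import torch
from torch.utils.cpp_extension import load

sparse_utils = load(name="sparse_attn_utils",
                    sources=["csrc/sparse_attention/utils.cpp"])

# layout: (heads, block_rows, block_cols) 0/1 mask of retained blocks; int32 to match
# the accessor<int, 3> calls in the C++ code above.
layout = torch.tril(torch.ones(1, 8, 8, dtype=torch.int32))

# Greedily group blocks into square super-blocks of width 4, then 2, then 1.
segments = sparse_utils.sdd_segment(layout, 4)
for width, blocks in segments:
    # each row of `blocks` is (head, block_row, block_col, block_index)
    print(width, blocks.shape)
```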
''' '''
Copyright 2020 The Microsoft DeepSpeed Team Copyright 2020 The Microsoft DeepSpeed Team
''' '''
import sys
import types
from deepspeed.pt.deepspeed_light import DeepSpeedLight from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.pt.deepspeed_light import ADAM_OPTIMIZER, LAMB_OPTIMIZER from deepspeed.runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from deepspeed.pt.deepspeed_lr_schedules import add_tuning_arguments from deepspeed.runtime.lr_schedules import add_tuning_arguments
from deepspeed.pt.log_utils import logger from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.pt.deepspeed_cuda import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from deepspeed.runtime.activation_checkpointing import checkpointing
from deepspeed.pt.deepspeed_config import DeepSpeedConfig from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from deepspeed.utils import logger
import deepspeed.pt.deepspeed_checkpointing as checkpointing
try: try:
from deepspeed.git_version_info import git_hash, git_branch from deepspeed.git_version_info import version, git_hash, git_branch
except ImportError: except ImportError:
version = "0.0.0+unknown"
git_hash = None git_hash = None
git_branch = None git_branch = None
# Export version information # Export version information
__version_major__ = 0 version, __version_tag__ = version.split('+')
__version_minor__ = 2 __version_major__ = int(version.split('.')[0])
__version_patch__ = 0 __version_minor__ = int(version.split('.')[1])
__version_patch__ = int(version.split('.')[2])
__version__ = '.'.join( __version__ = '.'.join(
map(str, map(str,
[__version_major__, [__version_major__,
__version_minor__, __version_minor__,
__version_patch__])) __version_patch__]))
__version__ = f"{__version__}+{__version_tag__}"
__git_hash__ = git_hash __git_hash__ = git_hash
__git_branch__ = git_branch __git_branch__ = git_branch
# Provide backwards compatibility with the old deepspeed.pt module structure; should hopefully not be used
pt = types.ModuleType('pt', 'dummy pt module for backwards compatibility')
deepspeed = sys.modules[__name__]
setattr(deepspeed, 'pt', pt)
setattr(deepspeed.pt, 'deepspeed_utils', deepspeed.runtime.utils)
sys.modules['deepspeed.pt'] = deepspeed.pt
sys.modules['deepspeed.pt.deepspeed_utils'] = deepspeed.runtime.utils
setattr(deepspeed.pt, 'deepspeed_config', deepspeed.runtime.config)
sys.modules['deepspeed.pt.deepspeed_config'] = deepspeed.runtime.config
def initialize(args, def initialize(args,
model, model,
...@@ -90,16 +104,16 @@ def initialize(args, ...@@ -90,16 +104,16 @@ def initialize(args,
__git_branch__), __git_branch__),
) )
engine = DeepSpeedLight(args=args, engine = DeepSpeedEngine(args=args,
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
model_parameters=model_parameters, model_parameters=model_parameters,
training_data=training_data, training_data=training_data,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
mpu=mpu, mpu=mpu,
dist_init_required=dist_init_required, dist_init_required=dist_init_required,
collate_fn=collate_fn, collate_fn=collate_fn,
config_params=config_params) config_params=config_params)
return_items = [ return_items = [
engine, engine,
......
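As a quick, hedged illustration of what the refactored package exposes (assuming an installed build of this commit), the version string now carries the git tag and old `deepspeed.pt` import paths still resolve through the compatibility shim above:

```python
# Illustrative sketch only; assumes deepspeed from this commit is installed.
import deepspeed
import deepspeed.runtime.config

print(deepspeed.__version__)                  # e.g. "0.3.0+<git info>", or "0.0.0+unknown" fallback
print(deepspeed.__git_branch__, deepspeed.__git_hash__)

# Old-style import path kept alive by the backwards-compatibility shim:
from deepspeed.pt import deepspeed_config
assert deepspeed_config is deepspeed.runtime.config
```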
...@@ -10,7 +10,7 @@ import base64 ...@@ -10,7 +10,7 @@ import base64
from collections import defaultdict from collections import defaultdict
from argparse import ArgumentParser, REMAINDER from argparse import ArgumentParser, REMAINDER
from deepspeed.pt.log_utils import logger from deepspeed.utils import logger
def parse_args(): def parse_args():
......
...@@ -14,8 +14,8 @@ from copy import deepcopy ...@@ -14,8 +14,8 @@ from copy import deepcopy
import torch.cuda import torch.cuda
from deepspeed.pt.deepspeed_constants import TORCH_DISTRIBUTED_DEFAULT_PORT from deepspeed.runtime.constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from deepspeed.pt.log_utils import logger from deepspeed.utils import logger
DLTS_HOSTFILE = "/job/hostfile" DLTS_HOSTFILE = "/job/hostfile"
EXPORT_ENVS = ["NCCL", "PYTHON"] EXPORT_ENVS = ["NCCL", "PYTHON"]
...@@ -285,7 +285,7 @@ def main(args=None): ...@@ -285,7 +285,7 @@ def main(args=None):
sys.executable, sys.executable,
"-u", "-u",
"-m", "-m",
"deepspeed.pt.deepspeed_launch", "deepspeed.launcher.launch",
"--world_info={}".format(world_info_base64), "--world_info={}".format(world_info_base64),
"--master_addr={}".format(args.master_addr), "--master_addr={}".format(args.master_addr),
"--master_port={}".format(args.master_port) "--master_port={}".format(args.master_port)
...@@ -328,7 +328,7 @@ def main(args=None): ...@@ -328,7 +328,7 @@ def main(args=None):
sys.executable, sys.executable,
"-u", "-u",
"-m", "-m",
"deepspeed.pt.deepspeed_launch", "deepspeed.launcher.launch",
'--world_info={}'.format(world_info_base64), '--world_info={}'.format(world_info_base64),
"--node_rank=%n", "--node_rank=%n",
"--master_addr={}".format(args.master_addr), "--master_addr={}".format(args.master_addr),
......
from deepspeed.ops.lamb.fused_lamb import FusedLamb
...@@ -3,46 +3,37 @@ Copyright 2019 The Microsoft DeepSpeed Team ...@@ -3,46 +3,37 @@ Copyright 2019 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
''' '''
import types import types
import torch
import importlib import importlib
import torch
class FusedLamb(torch.optim.Optimizer): class FusedLamb(torch.optim.Optimizer):
"""Implements LAMB algorithm. Currently GPU-only. Requires DeepSpeed adapted Apex to be installed via """Implements the LAMB algorithm. Currently GPU-only.
``python setup.py install --cuda_ext --cpp_ext``.
For usage example please see, TODO DeepSpeed Tutorial LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.
https://arxiv.org/abs/1904.00962 https://arxiv.org/abs/1904.00962
Arguments: Arguments:
params (iterable): iterable of parameters to optimize or dicts defining params (iterable): iterable of parameters to optimize or dicts defining
parameter groups. parameter groups.
lr (float, optional): learning rate. (default: 1e-3) lr (float, optional): learning rate. (default: 1e-3)
bias_correction (bool, optional): bias correction (default: True)
betas (Tuple[float, float], optional): coefficients used for computing betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999)) running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8) numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0)
min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step, eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False) second moment estimate as in the original paper. (default: False)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
.. _Adam\: A Method for Stochastic Optimization: max_grad_norm (float, optional): value used to clip global grad norm
https://arxiv.org/abs/1412.6980 (default: 0.0)
.. _On the Convergence of Adam and Beyond: max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0)
https://openreview.net/forum?id=ryQu7f-RZ min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
amsgrad (boolean, optional): NOT SUPPORTED in FusedLamb!
""" """
def __init__(self, def __init__(self,
params, params,
...@@ -58,7 +49,14 @@ class FusedLamb(torch.optim.Optimizer): ...@@ -58,7 +49,14 @@ class FusedLamb(torch.optim.Optimizer):
min_coeff=0.01, min_coeff=0.01,
amsgrad=False): amsgrad=False):
global fused_lamb_cuda global fused_lamb_cuda
fused_lamb_cuda = importlib.import_module("deepspeed_lamb_cuda") try:
fused_lamb_cuda = importlib.import_module(
"deepspeed.ops.lamb.fused_lamb_cuda")
except ImportError as err:
print(
"Unable to import Lamb cuda extension, please build DeepSpeed with cuda/cpp extensions."
)
raise err
if amsgrad: if amsgrad:
raise RuntimeError('FusedLamb does not support the AMSGrad variant.') raise RuntimeError('FusedLamb does not support the AMSGrad variant.')
...@@ -153,9 +151,7 @@ class FusedLamb(torch.optim.Optimizer): ...@@ -153,9 +151,7 @@ class FusedLamb(torch.optim.Optimizer):
if grad is None: if grad is None:
grad = p.grad.data grad = p.grad.data
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError( raise RuntimeError('FusedLamb does not support sparse gradients')
'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
)
state = self.state[p] state = self.state[p]
......
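For context, a minimal sketch of driving the relocated FusedLamb optimizer directly; it assumes a CUDA device and that the DeepSpeed lamb extension was built, and the hyperparameters are illustrative rather than recommended.

```python
# Hedged sketch; assumes the deepspeed fused lamb CUDA extension is built and a GPU is present.
import torch
from deepspeed.ops.lamb import FusedLamb

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = FusedLamb(model.parameters(),
                      lr=1e-3,
                      weight_decay=0.01,
                      max_coeff=10.0,   # clamp range for the computed lamb coefficient
                      min_coeff=0.01)

loss = model(torch.randn(8, 1024, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```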
from .sparsity_config import SparsityConfig, DenseSparsityConfig, FixedSparsityConfig, VariableSparsityConfig, BigBirdSparsityConfig, BSLongformerSparsityConfig
from .softmax import Softmax
from .matmul import MatMul
from .sparse_self_attention import SparseSelfAttention
from .bert_sparse_self_attention import BertSparseSelfAttention
from .sparse_attention_utils import SparseAttentionUtils
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
from torch import nn
from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig
class BertSparseSelfAttention(nn.Module):
"""Implements Sparse Self Attention layer of Bert model based on https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373
For more information please see, TODO DeepSpeed Sparse Transformer.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
"""
def __init__(
self,
config,
# SparsityConfig parameters need to be set accordingly
sparsity_config=FixedSparsityConfig(num_heads=4)):
"""Initialize the bert sparse self attention layer.
Note) you can use any of the provided sparsity configs or simply add yours!
Arguments:
config: required: Bert model config
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the FixedSparsityConfig class.
"""
super(BertSparseSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size,
config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.sparse_self_attention = SparseSelfAttention(sparsity_config)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads,
self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
"""Applies forward phase of bert sparse self attention
Arguments:
hidden_states: required: hidden_states tensor of the bert model
attention_mask: required: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
Return:
context_layer: a dense tensor containing attention context
"""
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
context_layer = self.sparse_self_attention(query_layer,
key_layer,
value_layer,
key_padding_mask=attention_mask)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
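A hedged usage sketch of the layer above; it requires triton and a CUDA device, and the SimpleNamespace stand-in for a real BERT config is an illustrative assumption.

```python
# Hedged sketch; requires triton + CUDA. The toy config stands in for a real BERT config.
import torch
from types import SimpleNamespace
from deepspeed.ops.sparse_attention import BertSparseSelfAttention, FixedSparsityConfig

config = SimpleNamespace(hidden_size=256, num_attention_heads=4)
attn = BertSparseSelfAttention(
    config,
    sparsity_config=FixedSparsityConfig(num_heads=4)).cuda().half()

seq_len = 256     # should be a multiple of the sparsity block size (16 by default)
hidden_states = torch.randn(2, seq_len, 256, device='cuda', dtype=torch.float16)
attention_mask = torch.zeros(2, seq_len, device='cuda', dtype=torch.float16)

context = attn(hidden_states, attention_mask)   # [2, seq_len, 256]
```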
This diff is collapsed.
# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
import warnings
try:
import triton
except ImportError:
warnings.warn("Unable to import triton, sparse attention will not be accessible")
import torch
import math
from deepspeed.ops.sparse_attention.trsrc import softmax_fwd, softmax_bwd
fwd_kernels = dict()
bwd_kernels = dict()
class _sparse_softmax(torch.autograd.Function):
bwd_kernels = dict()
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
sizes = _empty.clone()
# sizes along rows
for h in range(layout.shape[0]):
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
# offsets in block format
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
# block indices
idx = torch.arange(layout.sum())
head = layout.nonzero()[:, 0]
rows = layout.nonzero()[:, 1]
columns = layout.nonzero()[:, 2]
core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
# construct look-up table
offsets = offsets * 4 + 2 * sizes.numel()
header = torch.stack((sizes, offsets), dim=1).view(-1)
lut = torch.cat((header, core)).type(torch.int32).to(device)
return lut, int(sizes.max())
@staticmethod
def make_kernel(cache,
src,
max_k,
dtype,
block,
apply_scale,
apply_rpe,
apply_kp_mask,
apply_attn_mask,
kp_mask_mode,
attn_mask_mode):
if max_k >= 32768:
raise NotImplementedError('Reductions larger than 32768 elements '\
'are not yet implemented')
num_warps = 4 if max_k < 512 else (8 if max_k < 2048 else 16)
pad = num_warps * 32 * 2
TN = (int(max_k) + pad - 1) // pad * pad
# just-in-time compile kernel
key = (block,
dtype,
num_warps,
TN,
apply_scale,
apply_rpe,
apply_kp_mask,
apply_attn_mask,
kp_mask_mode,
attn_mask_mode)
if key not in cache:
defines = {
'TM': [1],
'TN': [TN],
'TYPE': dtype,
'BLOCK': block,
'INFINITY': {
torch.float32: 'F32_INFINITY',
torch.float16: 'F16_INFINITY'
}[dtype]
}
if apply_scale:
defines['APPLY_SCALE'] = True
if apply_rpe:
defines['APPLY_RPE'] = True
if apply_kp_mask:
defines['APPLY_KP_MASK'] = True
if kp_mask_mode == 'mul':
defines['KP_MASK_MUL'] = True
if apply_attn_mask:
defines['APPLY_ATTN_MASK'] = True
if attn_mask_mode == 'mul':
defines['ATTN_MASK_MUL'] = True
kernel = triton.kernel(src, defines=defines, num_warps=[num_warps])
cache[key] = kernel
return cache[key]
@staticmethod
def forward(ctx,
x,
scale,
rpe,
key_padding_mask,
attn_mask,
kp_mask_mode,
attn_mask_mode,
spdims,
block,
lut,
num_blocks,
maxlut,
bench,
time):
apply_scale = False if scale == 1.0 else True
# handle None rpe
if rpe is None:
apply_rpe = False
stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
rpe = torch.empty(0, dtype=x.dtype, device=x.device)
else:
apply_rpe = True
stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
# handle None key_padding_mask
if key_padding_mask is None:
apply_kp_mask = False
stride_zkpm = 0
key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
else:
apply_kp_mask = True
stride_zkpm = key_padding_mask.stride(0)
# handle None attention_mask
if attn_mask is None:
apply_attn_mask = False
stride_zattnm = 0
attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
else:
apply_attn_mask = True
stride_zattnm = attn_mask.stride(0)
# run kernel
kernel = _sparse_softmax.make_kernel(fwd_kernels,
softmax_fwd,
maxlut * block,
x.dtype,
block,
apply_scale,
apply_rpe,
apply_kp_mask,
apply_attn_mask,
kp_mask_mode,
attn_mask_mode)
M = x.shape[0]
grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.d('TM')), M]
# run kernel
time[0] = kernel(x, scale, lut, rpe, key_padding_mask, attn_mask,\
num_blocks, maxlut,\
x.stride(0),\
stride_zrpe, stride_hrpe, stride_srpe,\
stride_zkpm, stride_zattnm,\
grid=grid, bench=bench)
# save to context
ctx.mark_dirty(x)
ctx.save_for_backward(x, lut)
ctx.spdims = spdims
ctx.block = block
ctx.maxlut = maxlut
ctx.scale = scale
ctx.apply_scale = apply_scale
ctx.apply_rpe = apply_rpe
ctx.apply_kp_mask = apply_kp_mask
ctx.apply_attn_mask = apply_attn_mask
ctx.kp_mask_mode = kp_mask_mode
ctx.attn_mask_mode = attn_mask_mode
return x
@staticmethod
def backward(ctx, dx):
# retrieve from context
x, lut = ctx.saved_tensors
# run kernel
kernel = _sparse_softmax.make_kernel(bwd_kernels,
softmax_bwd,
ctx.maxlut * ctx.block,
x.dtype,
ctx.block,
ctx.apply_scale,
ctx.apply_rpe,
ctx.apply_kp_mask,
ctx.apply_attn_mask,
ctx.kp_mask_mode,
ctx.attn_mask_mode)
M = x.shape[0]
grid = lambda opt: [
triton.cdiv(ctx.spdims[0] * ctx.spdims[1] * ctx.block,
opt.d('TM')),
M
]
kernel(x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class Softmax:
"""Block-Sparse Softmax class; this class computes softmax on a block sparse matrix. It is also able to apply either/all of the following masks:
- relative position embedding
- key padding mask
- attention mask
For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
"""
sparse_softmax = _sparse_softmax.apply
def make_lut(self, device):
"""Generates the sparsity layout used in block-sparse softmax
"""
key = (device, )
if key not in self.lut_cache:
self.lut_cache[key] = _sparse_softmax.make_lut(self.layout,
self.block,
device)
return self.lut_cache[key]
def __init__(self, layout, block, bench=False):
"""Initialize the Block-Sparse Softmax class.
Arguments:
layout: required: sparsity layout tensor
block: required: an integer determining the block size.
bench: optional: set if you want to do benchmarking
"""
self.num_blocks = layout.sum()
self.spdims = layout.shape
self.layout = layout
self.block = block
self.bench = bench
self.lut_cache = dict()
def __call__(self,
x,
scale=1.,
rpe=None,
key_padding_mask=None,
attn_mask=None,
key_padding_mask_mode='add',
attn_mask_mode='add'):
"""Applies softmax on a Block-Sparse input tensor.
For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
Arguments:
x: required: a block-sparse tensor to which softmax is applied; the computation is performed in place and the result is returned in the same tensor
scale: optional: a float value; x values will be multiplied by this value before normalization. Default value is 1.0.
rpe: optional: a tensor same dimension as x that is used as relative position embedding
key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
key_padding_mask_mode: optional: a string determining if key_padding_mask needs to be added, `add`, or multiplied, `mul`
attn_mask_mode: optional: a string determining if attn_mask needs to be added, `add`, or multiplied, `mul`
Return:
x: the block-sparse input tensor, normalized in place with softmax and with the given masks applied
"""
time_y = [None]
if rpe is not None and rpe.dtype != x.dtype:
raise ValueError('relative position embedding must be %s' % x.dtype)
if attn_mask is not None and attn_mask.dtype != x.dtype:
raise ValueError('Attention mask must be %s' % x.dtype)
if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
raise ValueError('Key padding mask must be %s' % x.dtype)
lut, maxlut = self.make_lut(x.device)
x = Softmax.sparse_softmax(x,
scale,
rpe,
key_padding_mask,
attn_mask,
key_padding_mask_mode,
attn_mask_mode,
self.spdims,
self.block,
lut,
self.num_blocks,
maxlut,
self.bench,
time_y)
self.time_y = time_y[0]
return x
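To make the data flow concrete, here is a hedged sketch of the sdd-matmul, block-sparse softmax, and dsd-matmul pipeline that SparseSelfAttention (further below) builds from these ops. It assumes triton, a CUDA device, and a FixedSparsityConfig layout; shapes are illustrative.

```python
# Hedged sketch of the block-sparse attention pipeline; requires triton + CUDA.
import torch
from deepspeed.ops.sparse_attention import MatMul, Softmax, FixedSparsityConfig

num_heads, seq_len, head_dim = 4, 256, 64
config = FixedSparsityConfig(num_heads=num_heads)
layout = config.make_layout(seq_len)
block = config.block                       # sparsity block size (16 by default)

sdd = MatMul(layout, block, 'sdd', trans_a=False, trans_b=True)    # Q @ K^T -> block-sparse scores
dsd = MatMul(layout, block, 'dsd', trans_a=False, trans_b=False)   # probs @ V -> dense output
softmax = Softmax(layout, block)

q = torch.randn(2, num_heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

scores = sdd(q, k)
probs = softmax(scores, scale=head_dim ** -0.5)
out = dsd(probs, v)     # [2, num_heads, seq_len, head_dim]
```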
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import torch  # needed for torch.long / torch.cat in pad_to_block_size
from torch import nn
from torch.nn import functional as F
from deepspeed.ops.sparse_attention import BertSparseSelfAttention, SparsityConfig
'''
This file contains a few utility functions for adapting a pretrained model to use the sparse self-attention module.
'''
class SparseAttentionUtils:
"""This class provides some utility functions that are use integrating sparse attention into transformer models.
Such utilities include extending position embeddings, replacing current self-attention layer with sparse attention, padding sequences to multiple of block size, etc.
"""
@staticmethod
def extend_position_embedding(model, max_position):
"""This function extends the position embedding weights of a model loaded from a checkpoint.
It assumes the new max position is bigger than the original max length.
Arguments:
model: required: a transformer model
max_position: required: an integer determining new position embedding size
Return:
model: updated model; in which position embedding weights have been extended based on new size
"""
if hasattr(model, 'bert'):
original_max_position = model.bert.embeddings.position_embeddings.weight.size(
0)
assert max_position > original_max_position
extend_multiples = max(1, max_position // original_max_position)
model.bert.embeddings.position_embeddings.weight.data = model.bert.embeddings.position_embeddings.weight.repeat(
extend_multiples,
1)
elif hasattr(model, 'roberta'):
# RoBERTa has positions 0 & 1 reserved, so embedding size is max position + 2
original_max_position, embed_size = model.roberta.embeddings.position_embeddings.weight.shape
original_max_position -= 2
extend_multiples = max(1, max_position // original_max_position)
assert max_position > original_max_position
max_position += 2
extended_position_embedding = model.roberta.embeddings.position_embeddings.weight.new_empty(
max_position,
embed_size)
k = 2
for i in range(extend_multiples):
extended_position_embedding[k:(
k + original_max_position
)] = model.roberta.embeddings.position_embeddings.weight[2:]
k += original_max_position
model.roberta.embeddings.position_embeddings.weight.data = extended_position_embedding
else:
raise ValueError(
'Please extend \"extend_position_embedding\" function to support your model type. It currently only supports \"bert\" & \"roberta\"!'
)
model.config.max_position_embeddings = max_position
print(
f'Extended position embeddings to {original_max_position * extend_multiples}'
)
return model
@staticmethod
def update_tokenizer_model_max_length(tokenizer, max_position):
"""This function updates the position embedding length of a tokenizer to a new max position.
Arguments:
tokenizer: required: a transformer tokenizer
max_position: required: an integer determining new position embedding size
Return:
tokenizer: updated tokenizer; in which model maximum length has been extended based on new size
"""
tokenizer.model_max_length = max_position
tokenizer.init_kwargs['model_max_length'] = max_position
print(f'updated tokenizer model maximum length to {max_position}')
return tokenizer
@staticmethod
def replace_model_self_attention_with_sparse_self_attention(
model,
max_position,
# SparsityConfig parameters need to be set accordingly
sparsity_config=SparsityConfig(num_heads=4)):
"""This function replaces the self attention layers in model encoder with sparse self attention.
It currently supports bert and roberta model and can be easily extended to any other models following similar steps here.
For sparsityConfig, refer to the config class.
Arguments:
model: required: a transformer model
max_position: required: an integer determining new position embedding size
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class
Return:
model: updated model; in which the self attention layers have been replaced with the DeepSpeed Sparse Self Attention layer.
"""
if hasattr(model, 'bert'):
model.config.max_position_embeddings = max_position
SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
model.config,
model.bert.encoder.layer,
sparsity_config)
elif hasattr(model, 'roberta'):
model.config.max_position_embeddings = max_position + 2
SparseAttentionUtils.replace_self_attention_layer_with_sparse_self_attention_layer(
model.config,
model.roberta.encoder.layer,
sparsity_config)
else:
raise ValueError(
'Please extend \"update_model_self_attention_to_sparse_self_attention\" function to support \
your model type. It currently only supports \"bert\" & \"roberta\"!'
)
return model
@staticmethod
def replace_self_attention_layer_with_sparse_self_attention_layer(
config,
layers,
# SparsityConfig parameters need to be set accordingly
sparsity_config=SparsityConfig(num_heads=4)):
"""This function replaces the self attention layers in attention layer with sparse self attention.
For sparsityConfig, refer to the config class.
Arguments:
config: required: transformer model config
layers: required: transformer model attention layers
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class
Return:
layers: updated attention layers; in which the self attention layers have been replaced with the DeepSpeed Sparse Self Attention layer.
"""
for layer in layers:
deepspeed_sparse_self_attn = BertSparseSelfAttention(config, sparsity_config)
deepspeed_sparse_self_attn.query = layer.attention.self.query
deepspeed_sparse_self_attn.key = layer.attention.self.key
deepspeed_sparse_self_attn.value = layer.attention.self.value
layer.attention.self = deepspeed_sparse_self_attn
return layers
@staticmethod
def pad_to_block_size(block_size,
input_ids,
attention_mask,
token_type_ids,
position_ids,
inputs_embeds,
pad_token_id,
model_embeddings):
"""This function pads input tokens and attention mask on sequence length dimension to be multiple of block size.
This is a requirement for Sparse Transformer in which the self attention layer works on sequences of length multiple of block size.
It needs to be called in your model, such as BertModel, right before you calculate the embedding outputs.
Note)
1- instead of passing your embedding layer to this function, you can simply add this function to your model. It can be simplified further if the given attention_mask and/or token_type_ids are None.
2- you need to call the unpad function before returning your model output to unpad the encoder sequence output.
Arguments:
block_size: required: an integer determining the block size of sparsity config.
pad_token_id: required: an integer determining the pad token from the model config; such as bert.config.pad_token_id.
input_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary
attention_mask: a torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
token_type_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
position_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the indices of positions of each input sequence tokens in the position embeddings.
inputs_embeds: an optional torch.FloatTensor of shape [batch_size, sequence_length, hidden_size] that contains embedded representation and can be passed instead of input_ids directly.
model_embeddings: an optional object. If inputs_embeds are not none, this will be your model embeddings such as BertEmbeddings from your model such as BertModel. You can move this function inside your model and use self.embeddings instead of passing this parameter.
Return:
pad_len: an integer indicating how many tokens were appended to make the sequence length a multiple of the block size.
input_ids: if input_ids are not none padded input_ids otherwise none.
attention_mask: if attention_mask is not none padded attention_mask otherwise none.
token_type_ids: if token_type_ids are not none padded token_type_ids otherwise none.
position_ids: if position_ids are not none padded position_ids otherwise none.
inputs_embeds: if inputs_embeds are not none padded inputs_embeds otherwise none.
"""
batch_size, seq_len = input_ids.shape if input_ids is not None else inputs_embeds.shape[:-1]
pad_len = (block_size - seq_len % block_size) % block_size
if pad_len > 0:
if inputs_embeds is not None:
pad_input_ids = inputs_embeds.new_full((batch_size,
pad_len),
pad_token_id,
dtype=torch.long)
pad_inputs_embeds = model_embeddings(pad_input_ids)
inputs_embeds = torch.cat([inputs_embeds, pad_inputs_embeds], dim=-2)
# may not be needed as input_ids are not used if inputs_embeds are given
if input_ids is not None:
input_ids = F.pad(input_ids, (0, pad_len), value=pad_token_id)
if position_ids is not None:
# pad position_id with pad_token_id
position_ids = F.pad(position_ids, (0, pad_len), value=pad_token_id)
# pad attention mask without attention on the padding tokens
attention_mask = F.pad(attention_mask, (0, pad_len), value=False)
# pad token_type_ids with token_type_id = 0
token_type_ids = F.pad(token_type_ids, (0, pad_len), value=0)
return pad_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
@staticmethod
def unpad_sequence_output(pad_len, sequence_output):
"""This function unpads sequence output if inputs of the model were padded.
This is a requirement for Sparse Transformer in which the self attention layer works on sequences of length multiple of block size.
It needs to be called in your model, such as BertModel, right before you return the model outputs.
Arguments:
pad_len: required: an integer indicating how many tokens the model inputs were padded with to make the sequence length a multiple of the block size.
sequence_output: required: sequence output of the encoder layer.
Return:
sequence_output: unpadded sequence output of the encoder layer.
"""
if (pad_len > 0):
sequence_output = sequence_output[:, :-pad_len]
return sequence_output
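A hedged end-to-end sketch of applying these utilities to a Hugging Face BERT checkpoint; the `transformers` model and tokenizer names are assumptions, not part of this commit.

```python
# Hedged sketch; assumes the external `transformers` package and a model exposing `.bert`.
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.ops.sparse_attention import SparseAttentionUtils, FixedSparsityConfig

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

max_position = 2048   # new, longer maximum sequence length
model = SparseAttentionUtils.extend_position_embedding(model, max_position)
tokenizer = SparseAttentionUtils.update_tokenizer_model_max_length(tokenizer, max_position)
model = SparseAttentionUtils.replace_model_self_attention_with_sparse_self_attention(
    model,
    max_position,
    sparsity_config=FixedSparsityConfig(num_heads=model.config.num_attention_heads))

# Inside the (modified) BertModel forward pass, pad_to_block_size would then be called
# before computing embeddings and unpad_sequence_output before returning, as the
# docstrings above describe.
```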
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import torch.nn as nn
from torch.nn.functional import *
import torch
from collections import namedtuple
from deepspeed.ops.sparse_attention import MatMul, Softmax, SparsityConfig
import sys
class SparseSelfAttention(nn.Module):
"""Implements an efficient Sparse Self Attention of Transformer layer based on `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
For more information please see, TODO DeepSpeed Sparse Transformer.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
"""
def __init__(
self,
# SparsityConfig parameters need to be set accordingly
sparsity_config=SparsityConfig(num_heads=4),
key_padding_mask_mode='add',
attn_mask_mode='mul'):
"""Initialize the sparse self attention layer.
Arguments:
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class.
key_padding_mask_mode: optional: a string determining if key padding mask needs to be added, `add`, or be multiplied, `mul`.
attn_mask_mode: optional: a string determining if attention mask needs to be added, `add`, or be multiplied, `mul`.
"""
super().__init__()
# sparsity information
self.sparsity_config = sparsity_config
# mask modes
self.key_padding_mask_mode = key_padding_mask_mode
self.attn_mask_mode = attn_mask_mode
ops = dict()
# add to cache
def get_ops(self, H, L):
import sys
if L not in SparseSelfAttention.ops:
sparsity_layout = self.sparsity_config.make_layout(L)
sparse_dot_sdd_nt = MatMul(sparsity_layout,
self.sparsity_config.block,
'sdd',
trans_a=False,
trans_b=True)
sparse_dot_dsd_nn = MatMul(sparsity_layout,
self.sparsity_config.block,
'dsd',
trans_a=False,
trans_b=False)
sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)
SparseSelfAttention.ops[L] = (sparse_dot_sdd_nt,
sparse_dot_dsd_nn,
sparse_softmax)
return SparseSelfAttention.ops[L]
def transpose_key_for_scores(self, x, L):
bsz, num_heads, seq_len, head_dim = x.size()
if seq_len != L:
return x.permute(0, 1, 3, 2)
return x
def transpose_mask_for_sparse(self, qtype, x, is_key_padding_mask=False):
x = x.type(qtype)
if is_key_padding_mask:
xdim = x.dim()
for d in range(xdim - 1, 0, -1):
x = x.squeeze(dim=d)
return x
return x.squeeze()
# forward pass
def forward(self,
query,
key,
value,
rpe=None,
key_padding_mask=None,
attn_mask=None):
"""Applies forward phase of sparse self attention
Arguments:
query: required: query tensor
key: required: key tensor
value: required: value tensor
rpe: optional: a tensor same dimension as x that is used as relative position embedding
key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported
(key_padding_mask_mode and attn_mask_mode are configured in the constructor, not passed here.)
Return:
attn_output: a dense tensor containing attention context
"""
bsz, num_heads, tgt_len, head_dim = query.size()
# transpose back key if it is already transposed
key = self.transpose_key_for_scores(key, tgt_len)
# check that operation is supported
if query.shape != key.shape or key.shape != value.shape:
raise NotImplementedError('only self-attention is supported for now')
# squeeze key_padding_mask if it is given
if key_padding_mask is not None:
key_padding_mask = self.transpose_mask_for_sparse(query.dtype,
key_padding_mask,
is_key_padding_mask=True)
# squeeze attn_mask if it is given
if attn_mask is not None:
attn_mask = self.transpose_mask_for_sparse(query.dtype, attn_mask)
# cache look-up table computations etc
sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops(num_heads, tgt_len)
scaling = float(head_dim)**-0.5
# attention scores
attn_output_weights = sparse_dot_sdd_nt(query, key)
attn_output_weights = sparse_softmax(
attn_output_weights,
scale=scaling,
rpe=rpe,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
key_padding_mask_mode=self.key_padding_mask_mode,
attn_mask_mode=self.attn_mask_mode)
# outputs
attn_output = sparse_dot_dsd_nn(attn_output_weights, value)
return attn_output
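Finally, a hedged sketch of calling SparseSelfAttention directly on pre-split attention heads; it requires triton and a CUDA device, and the shapes are illustrative.

```python
# Hedged sketch; requires triton + CUDA. q/k/v are [batch, heads, seq_len, head_dim].
import torch
from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig

attn = SparseSelfAttention(sparsity_config=FixedSparsityConfig(num_heads=4))

bsz, heads, seq_len, head_dim = 2, 4, 256, 64
q = torch.randn(bsz, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

context = attn(q, k, v)   # dense [bsz, heads, seq_len, head_dim] attention output
```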