Unverified Commit e5bbc2e5 authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

Sparse attn + ops/runtime refactor + v0.3.0 (#343)



* Sparse attn + ops/runtime refactor + v0.3.0
Co-authored-by: default avatarArash Ashari <arashari@microsoft.com>
Co-authored-by: default avatarArash Ashari <arashari@microsoft.com>
parent 838f53b7
This diff is collapsed.
import sys
import os
def _build_file_index(directory, suffix='.tr'):
"""Build an index of source files and their basenames in a given directory.
Args:
directory (string): the directory to index
suffix (string): index files with this suffix
Returns:
list: A list of tuples of the form [(basename, absolute path), ...]
"""
index = []
for fname in os.listdir(directory):
if fname.endswith(suffix):
basename = fname[:fname.rfind(suffix)] # strip the suffix
path = os.path.join(directory, fname)
index.append((basename, path))
return index
# Go over all local source files and parse them as strings
_module = sys.modules[_build_file_index.__module__]
_directory = os.path.dirname(os.path.realpath(__file__))
for name, fname in _build_file_index(_directory):
with open(fname, 'r') as fin:
setattr(_module, name, fin.read())
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
TYPE* B __readonly __noalias __aligned(16),
TYPE* C __noalias __aligned(16),
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc __multipleof(8),
long stride_za __multipleof(8),
long stride_zb __multipleof(8),
long stride_zc __multipleof(8),
long stride_ha __multipleof(8),
long stride_hb __multipleof(8),
long stride_hc __multipleof(8),
int DS0, int DS1,
int SDD_K __multipleof(16),
int SDD_off_width,
int* lut, int* locks, int nlocks) {
/* ---------------- */
/* Prologue */
/* ---------------- */
// program ids
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pidz = get_program_id(2);
#ifdef SDD
// load LUT header
pid1 = pid1 + SDD_off_width;
int blockidm[TM] = (0 ... TM) / BLOCK;
int blockidn[TN] = (0 ... TN) / BLOCK;
int offlutm[TM] = blockidm*(TN/BLOCK)*4;
int offlutn[TN] = blockidn*4;
int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4;
int z = *(header + 0);
int i[TM] = *(header + 1 + offlutm);
int j[TN] = *(header + 2 + offlutn);
int AS1 = SDD_K / TZ;
int lockid = select(TZ > 1, 1, 0);
int offka = pid0 * AS1;
int offkb = pid0 * AS1;
int offmc = 0;
int offnc = 0;
int offpa = 0;
int offpb = 0;
int maxid = TZ;
int offhc = 0;
int offha = z;
int offhb = z;
int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK);
int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK);
#else
// load LUT header
int *header = lut + pid0 * 6;
int offset = *(header + 0);
int AS1 = *(header + 1);
int column = *(header + 2);
int depth = *(header + 3);
int lockid = *(header + 4);
int maxid = *(header + 5);
int *pinc = lut + offset;
int offhc = depth;
#ifdef DSD
// output offset
int offnc = pid1 * TN;
int offmc = column * TM;
int offpc = 0;
// dense input offset
int offnb = pid1 * TN;
int offkb __multipleof(8) = *pinc;
int offpb = 0;
// sparse input offset
int offma = 0;
int offka = 0;
long offpa __multipleof(8) = *(pinc + 1);
offpa = offpa * BLOCK * BLOCK;
int offha = 0;
int offhb = depth;
#endif
#ifdef DDS
// output offset
int offmc = pid1 * TM;
int offnc = column * TN;
int offpc = 0;
// dense input offset
int offma = pid1 * TM;
int offka __multipleof(8) = *pinc;
int offpa = 0;
// sparse input offset
int offnb = 0;
int offkb = 0;
long offpb __multipleof(8) = *(pinc + 1);
offpb = offpb * BLOCK * BLOCK;
int offha = depth;
int offhb = 0;
#endif
int ram[TM] = offma + 0 ... TM;
int rbn[TN] = offnb + 0 ... TN;
#endif
// initialize a, b pointers
int rka[TK] = offka + 0 ... TK;
int rkb[TK] = offkb + 0 ... TK;
TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK;
TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK;
// pre-fetch
#ifdef DDS
bool checkam[TM, TK] = ram[:, newaxis] < DS0;
#else
bool checkam[TM, TK] = AS1 > 0;
#endif
#ifdef DSD
bool checkbn[TK, TN] = rbn[newaxis, :] < DS0;
#else
bool checkbn[TK, TN] = AS1 > 0;
#endif
TYPE a[TM, TK] = checkam ? *pa : 0;
TYPE b[TK, TN] = checkbn ? *pb : 0;
/* ---------------- */
/* Inner Loop */
/* ---------------- */
// create result tile
float acc[TM, TN] = 0;
int step = TK;
for(int k = AS1; k > 0; k -= step) {
acc += a @ b;
// update pointers
#ifdef SDD
int inc_a = TK * STRIDE_AK;
int inc_b = TK * STRIDE_BK;
#else
pinc += 2;
#ifdef DSD
int inc_b __multipleof(8) = *pinc;
int inc_a __multipleof(8) = *(pinc + 1);
inc_b = inc_b * STRIDE_BK;
#endif
#ifdef DDS
int inc_a __multipleof(8) = *pinc;
int inc_b __multipleof(8) = *(pinc + 1);
inc_a = inc_a * STRIDE_AK;
#endif
#endif
pa += inc_a;
pb += inc_b;
// pre-fetch
bool checkak[TM, TK] = k > TK;
bool checkbk[TK, TN] = k > TK;
bool checka[TM, TK] = checkam && checkak;
bool checkb[TK, TN] = checkbk && checkbn;
a = *?(checka)pa;
b = *?(checkb)pb;
}
TYPE c[TM, TN] = acc;
/* ---------------- */
/* Epilogue */
/* ---------------- */
// initialize c pointers
#ifdef SDD
bool checkc[TM, TN] = 1;
// rematerialize
int rr_blockidm[TM] = (0 ... TM) / BLOCK;
int rr_blockidn[TN] = (0 ... TN) / BLOCK;
int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4;
int rr_offlutn[TN] = rr_blockidn*4;
int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :];
int bkid[TM, TN] = *(header + off_bkid);
long offpc[TM, TN] = bkid * BLOCK * BLOCK;
// range within blocks
int rcm[TM] = (0 ... TM) % BLOCK;
int rcn[TN] = (0 ... TN) % BLOCK;
#else
int rcm[TM] = offmc + 0 ... TM;
int rcn[TN] = offnc + 0 ... TN;
#ifdef DSD
bool checkc[TM, TN] = rcn[newaxis, :] < DS0;
#endif
#ifdef DDS
bool checkc[TM, TN] = rcm[:, newaxis] < DS0;
#endif
#endif
TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN;
// write-back directly
if(lockid == 0) {
*?(checkc) pc = c;
}
// accumulate partial result using spin-locks
else {
int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1;
int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks;
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % maxid);
atomic_xchg(plock, 0);
}
}
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
__global__ void softmax_bwd(TYPE * X __readonly __noalias __aligned(16),
float scale,
TYPE* DX __readonly __noalias __aligned(16),
int* LUT,
int sizemax,
long stride_zx __multipleof(BLOCK),
long stride_zdx __multipleof(BLOCK)) {
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
int* header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
// bounds checking on lut
bool check[TN] = rbn < size;
int rbmn[TN] = check ? rbn : size - 1;
// initialize pointers to block-sparse input
long blockid[TN] = *(LUT + offset + rbmn*4);
TYPE* px[TN] = X + pidz * stride_zx
+ blockid * BLOCK * BLOCK
+ rxm * BLOCK
+ rxn;
TYPE* pdx[TN] = DX + pidz * stride_zdx
+ blockid * BLOCK * BLOCK
+ rxm * BLOCK
+ rxn;
// compute fused softmax backward
TYPE x[TN] = check ? *px : 0;
TYPE dx[TN] = check ? *pdx : 0;
float Fdx[TN] = dx;
float Fx[TN] = x;
float Fxdx[TN] = Fdx*Fx;
float Fxdxsum = Fxdx[+];
float Fy[TN] = Fx * (Fdx - Fxdxsum) * scale;
TYPE y[TN] = Fy;
// write-back
*? (check)pdx = y;
}
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
__global__ void softmax_fwd(TYPE *X __readonly __noalias __aligned(16),
float scale,
int *LUT __readonly __noalias __aligned(16),
TYPE *RPE __readonly __noalias __aligned(16),
TYPE *KP_M __readonly __noalias __aligned(16),
TYPE *ATTN_M __readonly __noalias __aligned(16),
int num_blocks,
int sizemax,
long stride_zx __multipleof(BLOCK),
long stride_zrpe __multipleof(BLOCK),
int stride_hrpe __multipleof(BLOCK),
int stride_srpe __multipleof(BLOCK),
int stride_zkpm __multipleof(BLOCK),
int stride_zattnm __multipleof(BLOCK)){
int pidhm = get_program_id(0);
int pidz = get_program_id(1);
// create index ranges
int rxm = pidhm % BLOCK;
int rbm = pidhm / BLOCK;
int rxn[TN] = (0 ... TN) % BLOCK;
int rbn[TN] = (0 ... TN) / BLOCK;
// extract information from look-up table
int* header = LUT + rbm * 2;
int size = *(header + 0);
int offset = *(header + 1);
bool check[TN] = rbn < size;
int rbmn[TN] = check ? rbn : size - 1;
// block id and column id
long blockid [TN] = *(LUT + offset + rbmn*4 + 0);
long columnid[TN] = *(LUT + offset + rbmn*4 + 1);
long rowid [TN] = *(LUT + offset + rbmn*4 + 2);
long headid [TN] = *(LUT + offset + rbmn*4 + 3);
// pointers to X
TYPE* px[TN] = X + pidz * stride_zx
+ blockid * BLOCK * BLOCK
+ rxm * BLOCK
+ rxn;
#ifdef APPLY_RPE
// pointers to relative position embedding
TYPE* prpe[TN] = RPE + pidz * stride_zrpe
+ headid * stride_hrpe
+ columnid * BLOCK
+ rowid * BLOCK * stride_srpe
+ rxm * stride_srpe
+ rxn;
#endif
#ifdef APPLY_KP_MASK
// pointers to key padding mask
TYPE* pkp_m[TN] = KP_M + pidz * stride_zkpm
+ columnid * BLOCK
+ rxn;
#endif
#ifdef APPLY_ATTN_MASK
// pointers to attention mask
TYPE* pattn_m[TN] = ATTN_M + columnid * BLOCK
+ rowid * BLOCK * stride_zattnm
+ rxm * stride_zattnm
+ rxn;
#endif
// load input
TYPE x[TN] = check ? *px : -INFINITY;
#ifdef APPLY_RPE
// load relative position embedding
TYPE rpe[TN] = check ? *prpe : 0;
#endif
#ifdef APPLY_KP_MASK
// load key-padding mask
TYPE kp_m[TN] = check ? *pkp_m : -INFINITY;
#endif
#ifdef APPLY_ATTN_MASK
// load attention mask
TYPE attn_m[TN] = check ? *pattn_m : -INFINITY;
#endif
// compute softmax in float
#ifdef APPLY_RPE
float Frpe[TN] = rpe;
#endif
#ifdef APPLY_KP_MASK
float Fkp_m[TN] = kp_m;
#endif
#ifdef APPLY_ATTN_MASK
float Fattn_m[TN] = attn_m;
#endif
#ifdef KP_MASK_MUL
Fkp_m = (Fkp_m == 0) ? (float[TN])-INFINITY : 0;
#endif
#ifdef ATTN_MASK_MUL
Fattn_m = (Fattn_m == 0) ? (float[TN])-INFINITY : 0;
#endif
float Fx[TN] = x;
#ifdef APPLY_SCALE
Fx = Fx * scale; // apply scale
#endif
#ifdef APPLY_RPE
Fx = Fx + Frpe; // apply relative position embedding
#endif
#ifdef APPLY_KP_MASK
Fx = Fx + Fkp_m; // apply key padding mask
#endif
#ifdef APPLY_ATTN_MASK
Fx = Fx + Fattn_m; // apply attention mask
#endif
float Fxmax = Fx[max];
float Fy[TN] = exp(Fx - Fxmax);
float Fysum = (check ? Fy : 0)[+];
// write-back in half/float
TYPE y[TN] = Fy;
TYPE ysum = Fysum;
*?(check)px = y / ysum;
}
from deepspeed.ops.transformer.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from torch import nn
from torch.autograd import Function
import torch
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import json
import math
import deepspeed_transformer_cuda as ds_transformer_cuda
import deepspeed_stochastic_transformer_cuda as ds_stochastic_transformer_cuda
import importlib
import torch
from torch import nn
from torch.autograd import Function
# Cuda modules will be imported if needed
transformer_cuda_module = None
stochastic_transformer_cuda_module = None
class TransformerConfig():
......@@ -159,7 +165,7 @@ class DeepSpeedTransformerFunction(Function):
if bsz > config.batch_size:
raise ValueError('Input batch size exceeds the limit.')
cuda_module = ds_stochastic_transformer_cuda if config.stochastic_mode else ds_transformer_cuda
cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32
(output,
......@@ -321,7 +327,7 @@ class DeepSpeedTransformerFunction(Function):
norm_w,
norm_b) = ctx.saved_tensors
cuda_module = ds_stochastic_transformer_cuda if ctx.config.stochastic_mode else ds_transformer_cuda
cuda_module = stochastic_transformer_cuda_module if ctx.config.stochastic_mode else transformer_cuda_module
backward_func = cuda_module.backward_fp16 if ctx.config.fp16 else cuda_module.backward_fp32
(grad_input,
......@@ -457,8 +463,22 @@ class DeepSpeedTransformerLayer(nn.Module):
self.norm_w = initial_weights[7]
self.norm_b = initial_biases[7]
# Import cuda modules if needed
global transformer_cuda_module, stochastic_transformer_cuda_module
if transformer_cuda_module is None or stochastic_transformer_cuda_module is None:
try:
transformer_cuda_module = importlib.import_module(
"deepspeed.ops.transformer.transformer_cuda")
stochastic_transformer_cuda_module = importlib.import_module(
"deepspeed.ops.transformer.stochastic_transformer_cuda")
except ImportError as err:
print(
"Unable to import transformer cuda extension, please build DeepSpeed with cuda/cpp extensions."
)
raise err
# create the layer in cuda kernels.
cuda_module = ds_stochastic_transformer_cuda if self.config.stochastic_mode else ds_transformer_cuda
cuda_module = stochastic_transformer_cuda_module if self.config.stochastic_mode else transformer_cuda_module
create_layer_func = cuda_module.create_transformer_layer_fp16 if self.config.fp16 else cuda_module.create_transformer_layer_fp32
create_layer_func(self.config.layer_id,
......
......@@ -13,16 +13,17 @@ b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
import copy
import torch.distributed as dist
import torch
import contextlib
import torch.distributed as dist
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from deepspeed.pt.deepspeed_timer import SynchronizedWallClockTimer as Timers
import torch.distributed as dist
from deepspeed.pt.deepspeed_config import DeepSpeedConfig
from deepspeed.pt.log_utils import logger
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers
#DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False
......
......@@ -3,7 +3,7 @@ Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
from deepspeed.pt.deepspeed_config_utils import get_scalar_param
from deepspeed.runtime.config_utils import get_scalar_param
#########################################
# DeepSpeed Activation Checkpointing
......
......@@ -6,12 +6,12 @@ Licensed under the MIT license.
import torch
import json
import copy
from deepspeed.pt.deepspeed_constants import *
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.pt.deepspeed_config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from deepspeed.pt.deepspeed_zero_config import DeepSpeedZeroConfig
from deepspeed.pt.deepspeed_checkpointing_config import DeepSpeedActivationCheckpointingConfig
from deepspeed.pt.log_utils import logger
from deepspeed.runtime.constants import *
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.runtime.config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
from deepspeed.utils import logger
TENSOR_CORE_ALIGN_SIZE = 8
ADAM_OPTIMIZER = 'adam'
......@@ -158,6 +158,177 @@ def get_gradient_clipping(param_dict):
return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT)
def get_sparse_attention(param_dict):
if SPARSE_ATTENTION in param_dict.keys():
sparsity = param_dict[SPARSE_ATTENTION]
mode = get_sparse_attention_mode(sparsity)
if (mode == SPARSE_DENSE_MODE):
return get_sparse_dense_config(sparsity)
elif (mode == SPARSE_FIXED_MODE):
return get_sparse_fixed_config(sparsity)
elif (mode == SPARSE_VARIABLE_MODE):
return get_sparse_variable_config(sparsity)
elif (mode == SPARSE_BIGBIRD_MODE):
return get_sparse_bigbird_config(sparsity)
elif (mode == SPARSE_BSLONGFORMER_MODE):
return get_sparse_bslongformer_config(sparsity)
else:
raise NotImplementedError(
f'Given sparsity mode, {mode}, has not been implemented yet!')
else:
return None
def get_sparse_dense_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
return {SPARSE_MODE: SPARSE_DENSE_MODE, SPARSE_BLOCK: block}
def get_sparse_fixed_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
different_layout_per_head = get_scalar_param(
sparsity,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
num_local_blocks = get_scalar_param(sparsity,
SPARSE_NUM_LOCAL_BLOCKS,
SPARSE_NUM_LOCAL_BLOCKS_DEFAULT)
num_global_blocks = get_scalar_param(sparsity,
SPARSE_NUM_GLOBAL_BLOCKS,
SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
attention = get_scalar_param(sparsity,
SPARSE_ATTENTION_TYPE,
SPARSE_ATTENTION_TYPE_DEFAULT)
horizontal_global_attention = get_scalar_param(
sparsity,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT)
num_differnt_global_patterns = get_scalar_param(
sparsity,
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS,
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT)
return {
SPARSE_MODE: SPARSE_FIXED_MODE,
SPARSE_BLOCK: block,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
SPARSE_NUM_LOCAL_BLOCKS: num_local_blocks,
SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks,
SPARSE_ATTENTION_TYPE: attention,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention,
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS: num_differnt_global_patterns
}
def get_sparse_variable_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
different_layout_per_head = get_scalar_param(
sparsity,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
num_random_blocks = get_scalar_param(sparsity,
SPARSE_NUM_RANDOM_BLOCKS,
SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
local_window_blocks = get_scalar_param(sparsity,
SPARSE_LOCAL_WINDOW_BLOCKS,
SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT)
global_block_indices = get_scalar_param(sparsity,
SPARSE_GLOBAL_BLOCK_INDICES,
SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
global_block_end_indices = get_scalar_param(sparsity,
SPARSE_GLOBAL_BLOCK_END_INDICES,
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT)
attention = get_scalar_param(sparsity,
SPARSE_ATTENTION_TYPE,
SPARSE_ATTENTION_TYPE_DEFAULT)
horizontal_global_attention = get_scalar_param(
sparsity,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT)
return {
SPARSE_MODE: SPARSE_VARIABLE_MODE,
SPARSE_BLOCK: block,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
SPARSE_LOCAL_WINDOW_BLOCKS: local_window_blocks,
SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices,
SPARSE_ATTENTION_TYPE: attention,
SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention
}
def get_sparse_bigbird_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
different_layout_per_head = get_scalar_param(
sparsity,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
num_random_blocks = get_scalar_param(sparsity,
SPARSE_NUM_RANDOM_BLOCKS,
SPARSE_NUM_RANDOM_BLOCKS_DEFAULT)
num_sliding_window_blocks = get_scalar_param(
sparsity,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT)
num_global_blocks = get_scalar_param(sparsity,
SPARSE_NUM_GLOBAL_BLOCKS,
SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT)
return {
SPARSE_MODE: SPARSE_BIGBIRD_MODE,
SPARSE_BLOCK: block,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks
}
def get_sparse_bslongformer_config(sparsity):
block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT)
different_layout_per_head = get_scalar_param(
sparsity,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT)
num_sliding_window_blocks = get_scalar_param(
sparsity,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT)
global_block_indices = get_scalar_param(sparsity,
SPARSE_GLOBAL_BLOCK_INDICES,
SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT)
global_block_end_indices = get_scalar_param(sparsity,
SPARSE_GLOBAL_BLOCK_END_INDICES,
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT)
return {
SPARSE_MODE: SPARSE_BSLONGFORMER_MODE,
SPARSE_BLOCK: block,
SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head,
SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks,
SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices,
SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices
}
def get_sparse_attention_mode(param_dict):
if SPARSE_MODE in param_dict.keys():
return param_dict[SPARSE_MODE]
else:
return SPARSE_MODE_DEFAULT
def get_sparse_attention_type(param_dict):
if SPARSE_ATTENTION_TYPE in param_dict.keys():
return param_dict[SPARSE_ATTENTION_TYPE]
else:
return SPARSE_ATTENTION_TYPE_DEFAULT
def get_optimizer_name(param_dict):
if OPTIMIZER in param_dict.keys() and \
TYPE in param_dict[OPTIMIZER].keys():
......@@ -358,6 +529,8 @@ class DeepSpeedConfig(object):
self.tensorboard_output_path = get_tensorboard_output_path(param_dict)
self.tensorboard_job_name = get_tensorboard_job_name(param_dict)
self.sparse_attention = get_sparse_attention(param_dict)
def _batch_assertion(self):
train_batch = self.train_batch_size
......
......@@ -17,6 +17,42 @@ ROUTE_ENCODE = "encode"
TRAIN_BATCH_SIZE = "train_batch_size"
TRAIN_BATCH_SIZE_DEFAULT = None
#############################################
# Sparse attention
#############################################
SPARSE_ATTENTION = "sparse_attention"
SPARSE_DENSE_MODE = "dense"
SPARSE_FIXED_MODE = "fixed"
SPARSE_VARIABLE_MODE = "variable"
SPARSE_BIGBIRD_MODE = "bigbird"
SPARSE_BSLONGFORMER_MODE = "bslongformer"
SPARSE_MODE = "mode"
SPARSE_MODE_DEFAULT = SPARSE_FIXED_MODE
SPARSE_BLOCK = "block"
SPARSE_BLOCK_DEFAULT = 16
SPARSE_DIFFERENT_LAYOUT_PER_HEAD = "different_layout_per_head"
SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT = False
SPARSE_NUM_LOCAL_BLOCKS = "num_local_blocks"
SPARSE_NUM_LOCAL_BLOCKS_DEFAULT = 4
SPARSE_NUM_GLOBAL_BLOCKS = "num_global_blocks"
SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT = 1
SPARSE_ATTENTION_TYPE = "attention"
SPARSE_ATTENTION_TYPE_DEFAULT = "bidirectional"
SPARSE_HORIZONTAL_GLOBAL_ATTENTION = "horizontal_global_attention"
SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT = False
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS = "num_differnt_global_patterns"
SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT = 1
SPARSE_NUM_RANDOM_BLOCKS = "num_random_blocks"
SPARSE_NUM_RANDOM_BLOCKS_DEFAULT = 0
SPARSE_LOCAL_WINDOW_BLOCKS = "local_window_blocks"
SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT = [4]
SPARSE_GLOBAL_BLOCK_INDICES = "global_block_indices"
SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT = [0]
SPARSE_GLOBAL_BLOCK_END_INDICES = "global_block_end_indices"
SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT = None
SPARSE_NUM_SLIDING_WINDOW_BLOCKS = "num_sliding_window_blocks"
SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT = 3
#############################################
# Optimizer and lr scheduler
#############################################
......
......@@ -2,36 +2,35 @@
Copyright 2019 The Microsoft DeepSpeed Team
'''
import torch
import os
import torch
import warnings
import torch.distributed as dist
from apex import amp
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from apex import amp
from tensorboardX import SummaryWriter
from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer
from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.pt.log_utils import logger
import deepspeed.pt.deepspeed_checkpointing as deepspeed_activation_checkpointing
from deepspeed.pt.fp16_optimizer import FP16_Optimizer
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
from deepspeed.pt.deepspeed_constants import \
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
TORCH_DISTRIBUTED_DEFAULT_PORT, \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.ops.lamb import FusedLamb
import deepspeed.pt.deepspeed_lr_schedules as lr_schedules
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor
from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
SUMMARY_WRITER_DIR_NAME = "JobId"
......@@ -92,7 +91,7 @@ def print_configuration(args, name):
logger.info(' {} {} {}'.format(arg, dots, getattr(args, arg)))
class DeepSpeedLight(Module):
class DeepSpeedEngine(Module):
r"""DeepSpeed engine for training.
"""
def __init__(self,
......@@ -106,7 +105,7 @@ class DeepSpeedLight(Module):
dist_init_required=None,
collate_fn=None,
config_params=None):
super(DeepSpeedLight, self).__init__()
super(DeepSpeedEngine, self).__init__()
self.client_optimizer = optimizer
self.client_model_parameters = model_parameters
......
......@@ -9,9 +9,9 @@ import torch
import math
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.pt.log_utils import logger
from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger
class FP16_Optimizer(object):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment