Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
@@ -28,29 +31,11 @@ def num_warps(n):
    return 16


@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[6] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[6] * meta['BLOCK'])})
@triton.jit
def _forward(X, scale, LUT, RPE, KP_M, ATTN_M, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm,
             stride_zattnm, **meta):
    TN = meta['TN']
    BLOCK = meta['BLOCK']
    pidhm = tl.program_id(0)
@@ -102,14 +87,8 @@ def _forward(X,
    tl.store(px, x, mask=check)
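The two decorators above pick launch parameters from the row size. As a hedged, standalone sketch (these helper definitions are assumed from the decorator usage, not shown in this hunk): num_warps maps the row length in elements to a warp count, and next_power_of_2 rounds the per-row tile TN up to a power of two.

def next_power_of_2(n):
    # round n up to the nearest power of two (assumed helper)
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    return n + 1

def num_warps(n):
    # assumed thresholds: wider rows get more warps
    if n < 512:
        return 4
    if n < 2048:
        return 8
    return 16

assert next_power_of_2(300) == 512
assert num_warps(64 * 16) == 8   # e.g. 64 row blocks of size 16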
@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
@triton.heuristics({'TN': lambda *args, **meta: next_power_of_2(args[4]) * meta['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, **meta):
    pidhm = tl.program_id(0)
@@ -168,21 +147,8 @@ class _sparse_softmax(torch.autograd.Function):
        return lut, int(sizes.max())

    @staticmethod
    def forward(ctx, x, scale, rpe, key_padding_mask, attn_mask, kp_mask_mode, attn_mask_mode, spdims, block, lut,
                num_blocks, maxlut, bench, time):

        apply_scale = False if scale == 1.0 else True
@@ -251,14 +217,7 @@ class _sparse_softmax(torch.autograd.Function):
        # run kernel
        M = x.shape[0]
        grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
        _backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
        return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
@@ -270,6 +229,7 @@ class Softmax:
    For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
    """

    def sparse_softmax(*args, **kwargs):
        return _sparse_softmax.apply(*args, **kwargs)
@@ -278,9 +238,7 @@ class Softmax:
        """
        key = (device, )
        if key not in self.lut_cache:
            self.lut_cache[key] = _sparse_softmax.make_lut(self.layout, self.block, device)
        return self.lut_cache[key]

    def __init__(self, layout, block, bench=False):
@@ -332,19 +290,7 @@ class Softmax:
        if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
            raise ValueError('Key padding mask must be %s' % x.dtype)
        lut, maxlut = self.make_lut(x.device)
        x = Softmax.sparse_softmax(x, scale, rpe, key_padding_mask, attn_mask, key_padding_mask_mode, attn_mask_mode,
                                   self.spdims, self.block, lut, self.num_blocks, maxlut, self.bench, time_y)
        self.time_y = time_y[0]
        return x
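A minimal usage sketch of the Softmax wrapper above, assuming a CUDA build of DeepSpeed with triton available; the input shape (one block x block tile per non-zero layout block) is an assumption, not taken from this diff.

import torch
from deepspeed.ops.sparse_attention.softmax import Softmax

block, num_heads, num_blocks = 16, 4, 8
# per-head block-level layout: 1 marks a block that participates in attention
layout = torch.tril(torch.ones(num_blocks, num_blocks, dtype=torch.int64))
layout = layout.unsqueeze(0).repeat(num_heads, 1, 1)

sparse_softmax = Softmax(layout, block)
x = torch.randn(1, int(layout.sum()), block, block, dtype=torch.half, device='cuda')
probs = sparse_softmax(x, scale=1.0)   # softmax over each sparse row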
""" # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import torch
from torch.nn import functional as F
@@ -15,6 +16,7 @@ class SparseAttentionUtils:
    Such utilities include extending position embeddings, replacing current self-attention layer with sparse attention, padding sequences to multiple of block size, etc.
    """

    @staticmethod
    def extend_position_embedding(model, max_position):
        """This function extends the position embedding weights of a model loaded from a checkpoint.
@@ -28,13 +30,11 @@ class SparseAttentionUtils:
        """
        if hasattr(model, 'bert'):
            original_max_position = model.bert.embeddings.position_embeddings.weight.size(0)
            assert max_position > original_max_position
            extend_multiples = max(1, max_position // original_max_position)
            model.bert.embeddings.position_embeddings.weight.data = model.bert.embeddings.position_embeddings.weight.repeat(
                extend_multiples, 1)
        elif hasattr(model, 'roberta'):
            # RoBERTa has positions 0 & 1 reserved, so embedding size is max position + 2
            original_max_position, embed_size = model.roberta.embeddings.position_embeddings.weight.shape
@@ -43,13 +43,11 @@ class SparseAttentionUtils:
            assert max_position > original_max_position
            max_position += 2
            extended_position_embedding = model.roberta.embeddings.position_embeddings.weight.new_empty(
                max_position, embed_size)
            k = 2
            for i in range(extend_multiples):
                extended_position_embedding[k:(
                    k + original_max_position)] = model.roberta.embeddings.position_embeddings.weight[2:]
                k += original_max_position
            model.roberta.embeddings.position_embeddings.weight.data = extended_position_embedding
        else:
@@ -58,9 +56,7 @@ class SparseAttentionUtils:
            )
        model.config.max_position_embeddings = max_position
        print(f'Extended position embeddings to {original_max_position * extend_multiples}')

        return model
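A hedged usage sketch of the helper above; the checkpoint name is illustrative and transformers is assumed to be installed.

from transformers import BertForSequenceClassification
from deepspeed.ops.sparse_attention import SparseAttentionUtils

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')   # 512 positions
model = SparseAttentionUtils.extend_position_embedding(model, max_position=1024)
assert model.config.max_position_embeddings == 1024   # weights repeated 2x along dim 0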
@@ -102,21 +98,17 @@ class SparseAttentionUtils:
        if hasattr(model, 'bert'):
            model.config.max_position_embeddings = max_position
            model.replace_self_attention_layer_with_sparse_self_attention_layer(model.config, model.bert.encoder.layer,
                                                                                sparsity_config)
        elif hasattr(model, 'roberta'):
            model.config.max_position_embeddings = max_position + 2
            model.replace_self_attention_layer_with_sparse_self_attention_layer(model.config,
                                                                                model.roberta.encoder.layer,
                                                                                sparsity_config)
        else:
            raise ValueError(
                'Please extend \"update_model_self_attention_to_sparse_self_attention\" function to support \
                your model type. It currently only supports \"bert\" & \"roberta\"!')
        return model
    @staticmethod
@@ -148,14 +140,8 @@ class SparseAttentionUtils:
        return layers

    @staticmethod
    def pad_to_block_size(block_size, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds,
                          pad_token_id, model_embeddings):
        """This function pads input tokens and attention mask on sequence length dimension to be multiple of block size.
        This is a requirement for Sparse Transformer in which the self attention layer works on sequences of length multiple of block size.
        It needs to be called in your model, such as BertModel, right before you calculate the embedding outputs.
@@ -187,10 +173,7 @@ class SparseAttentionUtils:
        pad_len = (block_size - seq_len % block_size) % block_size
        if pad_len > 0:
            if inputs_embeds is not None:
                pad_input_ids = inputs_embeds.new_full((batch_size, pad_len), pad_token_id, dtype=torch.long)
                pad_inputs_embeds = model_embeddings(pad_input_ids)
                inputs_embeds = torch.cat([inputs_embeds, pad_inputs_embeds], dim=-2)
            # may not be needed as input_ids are not used if inputs_embeds are given
...
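The padding arithmetic used in pad_to_block_size above, shown standalone:

block_size, seq_len = 16, 70
pad_len = (block_size - seq_len % block_size) % block_size
assert pad_len == 10 and (seq_len + pad_len) % block_size == 0
# a sequence already aligned to the block size needs no padding
assert (block_size - 64 % block_size) % block_size == 0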
""" # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import torch.nn as nn
import torch
@@ -15,6 +16,7 @@ class SparseSelfAttention(nn.Module):
    For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
    """

    def __init__(
            self,
            # SparsityConfig parameters needs to be set accordingly
@@ -53,8 +55,7 @@ class SparseSelfAttention(nn.Module):
        if (L % self.sparsity_config.block != 0):
            raise ValueError(
                f'Sequence Length, {L}, needs to be dividable by Block size {self.sparsity_config.block}!')
        num_blocks = L // self.sparsity_config.block
        return self.master_layout[..., :num_blocks, :num_blocks].cpu()  # layout needs to be a CPU tensor
@@ -65,11 +66,7 @@ class SparseSelfAttention(nn.Module):
        from deepspeed.ops.sparse_attention.softmax import Softmax
        if L not in SparseSelfAttention.ops:
            sparsity_layout = self.get_layout(L)
            sparse_dot_sdd_nt = MatMul(sparsity_layout, self.sparsity_config.block, 'sdd', trans_a=False, trans_b=True)

            sparse_dot_dsd_nn = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
@@ -79,9 +76,7 @@ class SparseSelfAttention(nn.Module):
            sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)

            SparseSelfAttention.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax)
        return SparseSelfAttention.ops[L]
    def transpose_key_for_scores(self, x, L):
@@ -100,13 +95,7 @@ class SparseSelfAttention(nn.Module):
        return x.squeeze()

    # forward pass
    def forward(self, query, key, value, rpe=None, key_padding_mask=None, attn_mask=None):
        """Applies forward phase of sparse self attention

        Arguments:
@@ -134,9 +123,7 @@ class SparseSelfAttention(nn.Module):
        # squeeze key_padding_mask if it is given
        if key_padding_mask is not None:
            key_padding_mask = self.transpose_mask_for_sparse(query.dtype, key_padding_mask, is_key_padding_mask=True)

        # squeeze attn_mask if it is given
        if attn_mask is not None:
@@ -149,14 +136,13 @@ class SparseSelfAttention(nn.Module):
        # attention scores
        attn_output_weights = sparse_dot_sdd_nt(query, key)
        attn_output_weights = sparse_softmax(attn_output_weights,
                                             scale=scaling,
                                             rpe=rpe,
                                             key_padding_mask=key_padding_mask,
                                             attn_mask=attn_mask,
                                             key_padding_mask_mode=self.key_padding_mask_mode,
                                             attn_mask_mode=self.attn_mask_mode)

        # outputs
        attn_output = sparse_dot_dsd_nn(attn_output_weights, value)
...
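A minimal end-to-end sketch of the layer above, assuming a GPU with triton; the tensor shapes (batch, heads, seq, head_dim) and the fp16 dtype are assumptions:

import torch
from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig

config = FixedSparsityConfig(num_heads=4, block=16)
attn = SparseSelfAttention(sparsity_config=config)
q = k = v = torch.randn(2, 4, 256, 64, dtype=torch.half, device='cuda')
out = attn(q, k, v)   # same shape as q; sequence length 256 is a multiple of block=16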
""" # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import torch
import random
@@ -10,6 +11,7 @@ class SparsityConfig:
    """Abstract Configuration class to store `sparsity configuration of a self attention layer`.
    It contains shared property of different block-sparse sparsity patterns. However, each class needs to extend it based on required property and functionality.
    """

    def __init__(self, num_heads, block=16, different_layout_per_head=False):
        """Initialize the Sparsity Pattern Config.
@@ -37,9 +39,7 @@ class SparsityConfig:
        """
        if (seq_len % self.block != 0):
            raise ValueError(f'Sequence Length, {seq_len}, needs to be dividable by Block size {self.block}!')
        num_blocks = seq_len // self.block
        # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
        layout = torch.zeros((self.num_heads, num_blocks, num_blocks), dtype=torch.int64)
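Standalone illustration of the layout tensor allocated above: one num_blocks x num_blocks mask per head.

import torch

seq_len, block, num_heads = 128, 16, 4
assert seq_len % block == 0
num_blocks = seq_len // block
layout = torch.zeros((num_heads, num_blocks, num_blocks), dtype=torch.int64)
assert layout.shape == (4, 8, 8)   # subclasses then set the selected blocks to 1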
@@ -64,6 +64,7 @@ class DenseSparsityConfig(SparsityConfig):
    """Configuration class to store `Dense` configuration.
    In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension.
    """

    def __init__(self, num_heads, block=16, different_layout_per_head=False):
        """Initialize the Dense Sparsity Pattern Config.
        In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and comprehension.
@@ -96,6 +97,7 @@ class FixedSparsityConfig(SparsityConfig):
    For more details about this sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
    This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
    """

    def __init__(self,
                 num_heads,
                 block=16,
@@ -131,14 +133,11 @@ class FixedSparsityConfig(SparsityConfig):
        self.num_global_blocks = num_global_blocks

        if (attention != 'unidirectional' and attention != 'bidirectional'):
            raise NotImplementedError('only \"uni/bi-directional\" attentions are supported for now!')
        self.attention = attention

        if (attention != 'bidirectional' and horizontal_global_attention):
            raise ValueError('only \"bi-directional\" attentions can support horizontal global attention!')
        self.horizontal_global_attention = horizontal_global_attention

        if (num_different_global_patterns > 1 and not different_layout_per_head):
@@ -166,9 +165,7 @@ class FixedSparsityConfig(SparsityConfig):
        for i in range(0, num_blocks, self.num_local_blocks):
            end = min(i + self.num_local_blocks, num_blocks)
            for row in range(i, end):
                for col in range(i, (row + 1 if self.attention == 'unidirectional' else end)):
                    layout[h, row, col] = 1
        return layout
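A hedged re-creation of the local-window fill above for a single head, to make the produced pattern concrete:

import torch

num_blocks, num_local_blocks, attention = 8, 4, 'unidirectional'
layout = torch.zeros(num_blocks, num_blocks, dtype=torch.int64)
for i in range(0, num_blocks, num_local_blocks):
    end = min(i + num_local_blocks, num_blocks)
    for row in range(i, end):
        for col in range(i, (row + 1 if attention == 'unidirectional' else end)):
            layout[row, col] = 1
# each 4x4 diagonal tile is lower-triangular for unidirectional attention,
# and there is no attention across local windows
assert layout[:4, :4].equal(torch.tril(torch.ones(4, 4, dtype=torch.int64)))
assert layout[4:, :4].sum() == 0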
@@ -206,8 +203,7 @@ class FixedSparsityConfig(SparsityConfig):
        # set last global blocks; handle possible short last local window
        if (end < num_blocks):
            start = min(end + first_global_block_idx, num_blocks - self.num_global_blocks)
            end = start + self.num_global_blocks

        # vertical global attention
@@ -250,6 +246,7 @@ class VariableSparsityConfig(SparsityConfig):
    For more details about `Fixed` sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
    This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
    """

    def __init__(self,
                 num_heads,
                 block=16,
@@ -296,14 +293,11 @@ class VariableSparsityConfig(SparsityConfig):
        self.global_block_end_indices = global_block_end_indices

        if (attention != 'unidirectional' and attention != 'bidirectional'):
            raise NotImplementedError('only \"uni/bi-directional\" attentions are supported for now!')
        self.attention = attention

        if (attention != 'bidirectional' and horizontal_global_attention):
            raise ValueError('only \"bi-directional\" attentions can support horizontal global attention!')
        self.horizontal_global_attention = horizontal_global_attention

    def set_random_layout(self, h, layout):
@@ -345,9 +339,7 @@ class VariableSparsityConfig(SparsityConfig):
                end_block_idx += block_size
                end_block_idx = min(end_block_idx, num_blocks)
                for row in range(start_block_idx, end_block_idx):
                    for col in range(start_block_idx, (row + 1 if self.attention == 'unidirectional' else end_block_idx)):
                        layout[h, row, col] = 1
                start_block_idx += block_size
@@ -355,9 +347,7 @@ class VariableSparsityConfig(SparsityConfig):
        for i in range(start_block_idx, num_blocks, block_size):
            end_block_idx = min(i + block_size, num_blocks)
            for row in range(i, end_block_idx):
                for col in range(i, (row + 1 if self.attention == 'unidirectional' else end_block_idx)):
                    layout[h, row, col] = 1
        return layout
@@ -423,6 +413,7 @@ class BigBirdSparsityConfig(SparsityConfig):
    For more details about this sparsity config, please see `Big Bird: Transformers for Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf
    This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity.
    """

    def __init__(self,
                 num_heads,
                 block=16,
@@ -452,8 +443,7 @@ class BigBirdSparsityConfig(SparsityConfig):
        self.num_global_blocks = num_global_blocks

        if (attention != 'unidirectional' and attention != 'bidirectional'):
            raise NotImplementedError('only \"uni/bi-directional\" attentions are supported for now!')
        self.attention = attention

    def set_random_layout(self, h, layout):
@@ -475,10 +465,7 @@ class BigBirdSparsityConfig(SparsityConfig):
            )
        for row in range(0, num_blocks):
            sample_range = range(0, num_blocks) if self.attention == 'bidirectional' else range(0, row + 1)
            rnd_cols = random.sample(sample_range, self.num_random_blocks)
            layout[h, row, rnd_cols] = 1
        return layout
@@ -564,6 +551,7 @@ class BSLongformerSparsityConfig(SparsityConfig):
    For more details about this sparsity config, please see `Longformer: The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf
    This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity.
    """

    def __init__(self,
                 num_heads,
                 block=16,
@@ -687,11 +675,8 @@ class LocalSlidingWindowSparsityConfig(SparsityConfig):
    """Configuration class to store `Local Sliding Window` sparsity configuration - a purely-local sliding window attention.
    This class extends parent class of `SparsityConfig` and customizes it for `Local` sparsity.
    """

    def __init__(self, num_heads, block=16, num_sliding_window_blocks=3, attention='unidirectional'):
        """Initialize the Local Sliding Window Sparsity Pattern Config.
        For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
        Arguments:
@@ -723,8 +708,7 @@ class LocalSlidingWindowSparsityConfig(SparsityConfig):
        w = self.num_sliding_window_blocks // 2
        for row in range(0, num_blocks):
            start = max(0, row - w)
            end = min(row + w + 1, num_blocks) if self.attention == "bidirectional" else row + 1
            layout[h, row, start:end] = 1
        return layout
...
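Standalone sketch of the sliding-window fill above for one head (window of 3 blocks, bidirectional):

import torch

num_blocks, num_sliding_window_blocks, attention = 6, 3, 'bidirectional'
layout = torch.zeros(num_blocks, num_blocks, dtype=torch.int64)
w = num_sliding_window_blocks // 2
for row in range(num_blocks):
    start = max(0, row - w)
    end = min(row + w + 1, num_blocks) if attention == 'bidirectional' else row + 1
    layout[row, start:end] = 1
assert layout[0].tolist() == [1, 1, 0, 0, 0, 0]
assert layout[3].tolist() == [0, 0, 1, 1, 1, 0]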
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team

import sys
import os
...
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
*/
__global__ void NAME (TYPE* A __readonly __noalias __aligned(16),
                      TYPE* B __readonly __noalias __aligned(16),
...
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
*/
__global__ void softmax_bwd(TYPE * X __readonly __noalias __aligned(16),
                            float scale,
...
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/softmax.py
*/
__global__ void softmax_fwd(TYPE *X __readonly __noalias __aligned(16),
                            float scale,
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team

from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .inference.config import DeepSpeedInferenceConfig
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team

from .config import DeepSpeedInferenceConfig
from ....model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Optional
import torch
@@ -22,7 +23,4 @@ def nhwc_bias_add(activation: torch.Tensor,
    elif other_bias is None:
        return spatial_cuda_module.nhwc_bias_add_add(activation, bias, other)
    else:
        return spatial_cuda_module.nhwc_bias_add_bias_add(activation, bias, other, other_bias)
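What the dispatch above computes, written as a hedged pure-PyTorch reference; the fused CUDA kernels are assumed to implement these elementwise sums over NHWC tensors:

from typing import Optional
import torch

def nhwc_bias_add_reference(activation: torch.Tensor,
                            bias: torch.Tensor,
                            other: Optional[torch.Tensor] = None,
                            other_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # bias-only, bias+other, or bias+other+other_bias, mirroring the three kernel variants
    out = activation + bias
    if other is not None:
        out = out + other
        if other_bias is not None:
            out = out + other_bias
    return out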
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import json
from deepspeed.utils.types import ActivationFuncType


class TransformerConfig():

    def __init__(self, hidden_size, intermediate_size, heads, num_hidden_layers):
        self.layer_id = -1
        self.hidden_size = hidden_size
@@ -40,6 +43,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
        return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
        bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture.
    """

    def __init__(self,
                 hidden_size=-1,
                 intermediate_size=-1,
@@ -65,16 +69,16 @@ class DeepSpeedInferenceConfig(TransformerConfig):
                 training_mp_size=1,
                 bigscience_bloom=False,
                 max_out_tokens=1024,
                 min_out_tokens=1,
                 enable_qkv_quantization=False,
                 use_mup=False,
                 scale_attn_by_inverse_layer_idx=False,
                 return_single_tuple=False,
                 set_empty_params=False,
                 transposed_mode=False):
        super(DeepSpeedInferenceConfig,
              self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
                             num_hidden_layers)
        self.fp16 = fp16
        self.pre_layer_norm = pre_layer_norm
        self.local_rank = local_rank
@@ -96,10 +100,13 @@ class DeepSpeedInferenceConfig(TransformerConfig):
        self.training_mp_size = training_mp_size
        self.bigscience_bloom = bigscience_bloom
        self.max_out_tokens = max_out_tokens
        self.min_out_tokens = min_out_tokens
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.enable_qkv_quantization = enable_qkv_quantization
        self.use_mup = use_mup
        self.return_single_tuple = return_single_tuple
        self.set_empty_params = set_empty_params
        self.transposed_mode = transposed_mode

    @classmethod
    def from_dict(cls, json_object):
...
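A hedged construction example for the config above; the new fields added in this diff keep their defaults unless overridden:

from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig

cfg = DeepSpeedInferenceConfig(hidden_size=1024, heads=16, fp16=True)
assert cfg.intermediate_size == 4 * 1024   # falls back to 4 * hidden_size when unset
assert cfg.min_out_tokens == 1 and cfg.set_empty_params is False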
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team


class Diffusers2DTransformerConfig():

    def __init__(self, int8_quantization=False):
        self.int8_quantization = int8_quantization
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import math
import torch
from torch.autograd import Function
@@ -30,26 +32,12 @@ def load_triton_flash_attn():
class DeepSpeedDiffusersAttentionFunction(Function):

    @staticmethod
    def forward(ctx, input, context, input_mask, config, attn_qkvw, attn_qw, attn_kw, attn_vw, attn_qkvb,
                num_attention_heads_per_partition, norm_factor, hidden_size_per_partition, attn_ow, attn_ob,
                do_out_bias, score_context_func, linear_func, triton_flash_attn_kernel):

        def _transpose_for_context(x):
            x = x.permute(0, 2, 1, 3)
            new_x_layer_shape = x.size()[:-2] + \
@@ -58,8 +46,7 @@ class DeepSpeedDiffusersAttentionFunction(Function):
        def _transpose_for_scores(x):
            attention_head_size = x.shape[-1] // num_attention_heads_per_partition
            new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, attention_head_size)
            x = x.reshape(*new_x_shape)
            x = x.permute(0, 2, 1, 3)
            return x.contiguous()
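Shape walk-through of _transpose_for_scores above, with assumed example sizes:

import torch

batch, seq, heads, head_size = 2, 10, 4, 16
x = torch.randn(batch, seq, heads * head_size)
x = x.reshape(batch, seq, heads, head_size).permute(0, 2, 1, 3).contiguous()
assert x.shape == (2, 4, 10, 16)   # (batch, heads, seq, head_size)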
@@ -71,19 +58,12 @@ class DeepSpeedDiffusersAttentionFunction(Function):
            do_flash_attn = (head_size <= 128)
            scale = (1 / norm_factor) * (1 / norm_factor)
            if do_flash_attn and context == None:
                qkv_out = linear_func(input, attn_qkvw, attn_qkvb if attn_qkvb is not None else attn_qkvw, attn_qkvb
                                      is not None, do_flash_attn, config.heads, False)

                context_layer = triton_flash_attn_kernel(qkv_out[0], qkv_out[1], qkv_out[2], scale,
                                                         input.shape[-2] % 128 == 0)
                context_layer = _transpose_for_context(context_layer[:, :, :, :head_size])
            else:
                do_flash_attn = False
@@ -97,21 +77,12 @@ class DeepSpeedDiffusersAttentionFunction(Function):
                query = query.contiguous()
                key = key.contiguous()
                value = value.contiguous()
                query, key, value = inference_cuda_module.pad_transform_fp16(query, key, value, config.heads,
                                                                             do_flash_attn)
                attention_scores = (torch.matmul(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)
                context_layer = _transpose_for_context(torch.matmul(attention_scores, value))

            output = linear_func(context_layer, attn_ow, attn_ob, do_out_bias, False, config.heads, False)
            return output

        output = selfAttention_fp(input, context, input_mask)
@@ -142,8 +113,7 @@ class DeepSpeedDiffusersAttention(nn.Module):
        self.config = config
        self.config.layer_id = DeepSpeedDiffusersAttention.layer_id
        DeepSpeedDiffusersAttention.layer_id += 1
        device = get_accelerator().current_device_name() if config.bigscience_bloom else 'cpu'
        qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3

        data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
@@ -176,9 +146,7 @@ class DeepSpeedDiffusersAttention(nn.Module):
                                                  dtype=data_type,
                                                  device=device),
                                      requires_grad=False)
        self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
                                      requires_grad=False)
        out_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
@@ -187,9 +155,7 @@ class DeepSpeedDiffusersAttention(nn.Module):
                                                device=device),
                                    requires_grad=False)

        self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                    requires_grad=False)

        self.do_out_bias = True
@@ -200,8 +166,7 @@ class DeepSpeedDiffusersAttention(nn.Module):
        self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
        self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads

        self.norm_factor = math.sqrt(math.sqrt(self.config.hidden_size // self.config.heads))
        if self.config.scale_attn_by_inverse_layer_idx is True:
            self.norm_factor *= math.sqrt(self.config.layer_id + 1)
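Why norm_factor is a double square root: the kernel multiplies the scores by (1 / norm_factor) twice (see the scale computation earlier in this file), so the effective scale is the usual 1 / sqrt(head_size):

import math

hidden_size, heads = 1024, 16
head_size = hidden_size // heads
norm_factor = math.sqrt(math.sqrt(head_size))
assert abs((1 / norm_factor) ** 2 - 1 / math.sqrt(head_size)) < 1e-12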
@@ -216,33 +181,15 @@ class DeepSpeedDiffusersAttention(nn.Module):
    def forward(self, input, context=None, input_mask=None):
        if self.config.layer_id == 0:
            self.allocate_workspace(self.config.hidden_size, self.config.heads,
                                    input.size()[1],
                                    input.size()[0], DeepSpeedDiffusersAttention.layer_id, self.config.mp_size, False,
                                    0, self.config.max_out_tokens, self.config.min_out_tokens)
        output = DeepSpeedDiffusersAttentionFunction.apply(input, context, input_mask, self.config, self.attn_qkvw,
                                                           self.attn_qw, self.attn_kw, self.attn_vw, self.attn_qkvb,
                                                           self.num_attention_heads_per_partition, self.norm_factor,
                                                           self.hidden_size_per_partition, self.attn_ow, self.attn_ob,
                                                           self.do_out_bias, self.score_context_func, self.linear_func,
                                                           self.triton_flash_attn_kernel)

        return output
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import torch.nn as nn
@@ -31,41 +32,30 @@ def load_spatial_module():
class DeepSpeedDiffusersTransformerBlock(nn.Module):

    def __init__(self, equivalent_module: nn.Module, config: Diffusers2DTransformerConfig):
        super(DeepSpeedDiffusersTransformerBlock, self).__init__()
        self.quantizer = module_inject.GroupQuantizer(q_int8=config.int8_quantization)
        # Ensure ops are built by the time we start running
        self.config = config

        self.ff1_w = self.quantizer.quantize(
            nn.Parameter(equivalent_module.ff.net[0].proj.weight.data, requires_grad=False))
        self.ff1_b = nn.Parameter(equivalent_module.ff.net[0].proj.bias.data, requires_grad=False)
        self.ff2_w = self.quantizer.quantize(nn.Parameter(equivalent_module.ff.net[2].weight.data,
                                                          requires_grad=False))
        self.ff2_b = nn.Parameter(equivalent_module.ff.net[2].bias.data, requires_grad=False)

        self.norm1_g = nn.Parameter(equivalent_module.norm1.weight.data, requires_grad=False)
        self.norm1_b = nn.Parameter(equivalent_module.norm1.bias.data, requires_grad=False)
        self.norm1_eps = equivalent_module.norm1.eps

        self.norm2_g = nn.Parameter(equivalent_module.norm2.weight.data, requires_grad=False)
        self.norm2_b = nn.Parameter(equivalent_module.norm2.bias.data, requires_grad=False)
        self.norm2_eps = equivalent_module.norm2.eps

        self.norm3_g = nn.Parameter(equivalent_module.norm3.weight.data, requires_grad=False)
        self.norm3_b = nn.Parameter(equivalent_module.norm3.bias.data, requires_grad=False)
        self.norm3_eps = equivalent_module.norm3.eps

        self.attn_1 = equivalent_module.attn1
@@ -76,16 +66,14 @@ class DeepSpeedDiffusersTransformerBlock(nn.Module):
            self.attn_1.do_out_bias = False
            self.attn_1_bias = self.attn_1.attn_ob
        else:
            self.attn_1_bias = nn.Parameter(torch.zeros_like(self.norm2_g), requires_grad=False)

        # Pull the bias in if we can
        if isinstance(self.attn_2, DeepSpeedDiffusersAttention):
            self.attn_2.do_out_bias = False
            self.attn_2_bias = self.attn_2.attn_ob
        else:
            self.attn_2_bias = nn.Parameter(torch.zeros_like(self.norm3_g), requires_grad=False)

        self.transformer_cuda_module = load_transformer_module()
        load_spatial_module()
@@ -99,25 +87,14 @@ class DeepSpeedDiffusersTransformerBlock(nn.Module):
if "encoder_hidden_states" in kwargs and kwargs["encoder_hidden_states"] != None: if "encoder_hidden_states" in kwargs and kwargs["encoder_hidden_states"] != None:
context = kwargs["encoder_hidden_states"] context = kwargs["encoder_hidden_states"]
out_norm_1 = self.transformer_cuda_module.layer_norm(hidden_states, out_norm_1 = self.transformer_cuda_module.layer_norm(hidden_states, self.norm1_g, self.norm1_b, self.norm1_eps)
self.norm1_g,
self.norm1_b,
self.norm1_eps)
out_attn_1 = self.attn_1(out_norm_1) out_attn_1 = self.attn_1(out_norm_1)
out_norm_2, out_attn_1 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res(out_attn_1, out_norm_2, out_attn_1 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res(
self.attn_1_bias, out_attn_1, self.attn_1_bias, hidden_states, self.norm2_g, self.norm2_b, self.norm2_eps)
hidden_states,
self.norm2_g,
self.norm2_b,
self.norm2_eps)
out_attn_2 = self.attn_2(out_norm_2, context=context) out_attn_2 = self.attn_2(out_norm_2, context=context)
out_norm_3, out_attn_2 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res(out_attn_2, out_norm_3, out_attn_2 = self.transformer_cuda_module.layer_norm_residual_store_pre_ln_res(
self.attn_2_bias, out_attn_2, self.attn_2_bias, out_attn_1, self.norm3_g, self.norm3_b, self.norm3_eps)
out_attn_1,
self.norm3_g,
self.norm3_b,
self.norm3_eps)
out_ff1 = nn.functional.linear(out_norm_3, self.ff1_w) out_ff1 = nn.functional.linear(out_norm_3, self.ff1_w)
out_geglu = self.transformer_cuda_module.bias_geglu(out_ff1, self.ff1_b) out_geglu = self.transformer_cuda_module.bias_geglu(out_ff1, self.ff1_b)
......
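A hedged pure-PyTorch reference for the fused op used twice above; judging from its name and from how its second return value is reused as the next residual, it is assumed to add the bias and residual, return LayerNorm of that sum, and also return the pre-LayerNorm sum:

import torch
import torch.nn.functional as F

def layer_norm_residual_store_pre_ln_res_reference(x, bias, residual, gamma, beta, eps):
    pre_ln = x + bias + residual
    return F.layer_norm(pre_ln, gamma.shape, gamma, beta, eps), pre_ln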
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import math
import torch
@@ -14,6 +15,7 @@ minus_inf = -10000.0
class DeepSpeedSelfAttention(nn.Module):
    num_layers = 0
    _qkv_buffers = []
    def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1):
        super(DeepSpeedSelfAttention, self).__init__()
@@ -22,29 +24,36 @@ class DeepSpeedSelfAttention(nn.Module):
        data_type_fp = torch.half if config.fp16 else torch.float
        self.config.layer_id = DeepSpeedSelfAttention.num_layers
        DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
        device = get_accelerator().current_device_name()  #if config.bigscience_bloom else 'cpu'
        if self.config.set_empty_params:
            self.attn_qw = None
            self.attn_qb = None
            self.attn_kw = None
            self.attn_kb = None
            self.attn_vw = None
            self.attn_vb = None
            self.attn_qkvw = None
            self.attn_qkvb = None
            self.attn_ow = None
            self.attn_ob = None
        else:
            qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
            self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                      qkv_size_per_partition,
                                                      dtype=data_type,
                                                      device=device),
                                          requires_grad=False)
            self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
                                          requires_grad=False)
            out_size_per_partition = self.config.hidden_size // self.config.mp_size
            self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
                                                    self.config.hidden_size,
                                                    dtype=data_type,
                                                    device=device),
                                        requires_grad=False)

            self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
                                        requires_grad=False)
        self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
        self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
@@ -69,6 +78,14 @@ class DeepSpeedSelfAttention(nn.Module):
        self.score_context_func = SoftmaxContextOp(config)
        self.linear_func = LinearOp(config)
        self.vector_matmul_func = VectorMatMulOp(config)
        if len(DeepSpeedSelfAttention._qkv_buffers) == 0:
            DeepSpeedSelfAttention._qkv_buffers = [
                torch.empty(self.hidden_size_per_partition * 3,
                            self.config.hidden_size,
                            dtype=data_type_fp,
                            device=device),
                torch.empty(self.hidden_size_per_partition * 3, dtype=data_type_fp, device=device)
            ]
def compute_attention(self, qkv_out, input_mask, layer_past, alibi): def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
if isinstance(qkv_out, list): if isinstance(qkv_out, list):
...@@ -93,6 +110,18 @@ class DeepSpeedSelfAttention(nn.Module): ...@@ -93,6 +110,18 @@ class DeepSpeedSelfAttention(nn.Module):
context_layer, key_layer, value_layer = attn_key_value context_layer, key_layer, value_layer = attn_key_value
return context_layer, key_layer, value_layer return context_layer, key_layer, value_layer
def _merge_qkv(self):
qvkw = DeepSpeedSelfAttention._qkv_buffers[0]
qvkw[:self.hidden_size_per_partition, :] = self.attn_qw
qvkw[self.hidden_size_per_partition:2 * self.hidden_size_per_partition, :] = self.attn_kw
qvkw[2 * self.hidden_size_per_partition:, :] = self.attn_vw
if self.attn_qb is not None:
qvkb = DeepSpeedSelfAttention._qkv_buffers[1]
qvkb[:self.hidden_size_per_partition] = self.attn_qb
qvkb[self.hidden_size_per_partition:2 * self.hidden_size_per_partition] = self.attn_kb
qvkb[2 * self.hidden_size_per_partition:] = self.attn_vb
return DeepSpeedSelfAttention._qkv_buffers
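For readers following the new set_empty_params path: when q/k/v are loaded as separate per-shard tensors, _merge_qkv copies them row-wise into the shared class-level buffer before the fused QKV GEMM. A minimal standalone sketch of that copy, with hypothetical sizes and plain torch (no DeepSpeed dependency):

import torch

# Hypothetical per-partition sizes; the real values come from hidden_size / mp_size.
hidden_size = 8
partition = 8  # hidden_size_per_partition

# Separate projection weights, each [partition, hidden_size], as loaded per shard.
attn_qw = torch.randn(partition, hidden_size, dtype=torch.half)
attn_kw = torch.randn(partition, hidden_size, dtype=torch.half)
attn_vw = torch.randn(partition, hidden_size, dtype=torch.half)

# Preallocated shared buffer, [3 * partition, hidden_size], reused by every layer.
qkv_buffer = torch.empty(3 * partition, hidden_size, dtype=torch.half)

# Row-wise copy in q, k, v order, mirroring the slicing in _merge_qkv above.
qkv_buffer[:partition, :] = attn_qw
qkv_buffer[partition:2 * partition, :] = attn_kw
qkv_buffer[2 * partition:, :] = attn_vw

assert torch.equal(qkv_buffer, torch.cat([attn_qw, attn_kw, attn_vw], dim=0))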
def forward(self, def forward(self,
input, input,
input_mask, input_mask,
...@@ -105,44 +134,44 @@ class DeepSpeedSelfAttention(nn.Module): ...@@ -105,44 +134,44 @@ class DeepSpeedSelfAttention(nn.Module):
norm_w=None, norm_w=None,
norm_b=None, norm_b=None,
alibi=None): alibi=None):
if self.attn_qkvw is None:
self._attn_qkvw, self._attn_qkvb = self._merge_qkv()
else:
self._attn_qkvw = self.attn_qkvw
self._attn_qkvb = self.attn_qkvb
if not self.config.pre_layer_norm: if not self.config.pre_layer_norm:
qkv_out = self.linear_func(input=input, qkv_out = self.linear_func(input=input,
weight=self.attn_qkvw, weight=self._attn_qkvw,
bias=self.attn_qkvb, bias=self._attn_qkvb,
add_bias=self.attn_qkvb is not None, add_bias=self.attn_qkvb is not None,
do_flash_attn=False, do_flash_attn=False,
num_heads=self.num_attention_heads_per_partition, num_heads=self.num_attention_heads_per_partition,
num_layers=DeepSpeedSelfAttention.num_layers) num_layers=DeepSpeedSelfAttention.num_layers)
else: else:
qkv_out = self.qkv_func( qkv_out = self.qkv_func(input=input,
input=input, weight=self._attn_qkvw,
weight=self.attn_qkvw, bias=(self._attn_qkvb if self._attn_qkvb is not None else norm_b),
bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b), gamma=norm_w,
gamma=norm_w, beta=norm_b,
beta=norm_b, add_bias=(self.attn_qkvb is not None),
add_bias=(self.attn_qkvb is not None), num_layers=DeepSpeedSelfAttention.num_layers,
num_layers=DeepSpeedSelfAttention.num_layers, num_heads=self.num_attention_heads_per_partition)
num_heads=self.num_attention_heads_per_partition) context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
input_mask=input_mask,
context_layer, key_layer, value_layer = self.compute_attention( layer_past=layer_past,
qkv_out=qkv_out, alibi=alibi)
input_mask=input_mask,
layer_past=layer_past,
alibi=alibi)
output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow) output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)
inp_norm = qkv_out[-1] inp_norm = qkv_out[-1]
if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size( if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
group=self.mp_group) > 1:
dist.all_reduce(output, group=self.mp_group) dist.all_reduce(output, group=self.mp_group)
return (output, key_layer, value_layer, context_layer, inp_norm) return (output, key_layer, value_layer, context_layer, inp_norm)
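The trailing all_reduce in forward() is what makes the partitioned output projection correct under tensor parallelism: each rank multiplies its column slice of the context by its row slice of attn_ow, and the reduce sums the partial products. A small single-process sketch of that identity, using hypothetical sizes rather than the real config values:

import torch

torch.manual_seed(0)
hidden, parts = 8, 2                      # parts plays the role of mp_size
context = torch.randn(4, hidden)          # [tokens, hidden]
attn_ow = torch.randn(hidden, hidden)     # full (unpartitioned) output projection

full = context @ attn_ow                  # single-GPU reference

# Simulate `parts` ranks, each owning hidden // parts columns of the context
# and the matching rows of attn_ow.
chunk = hidden // parts
partials = [
    context[:, r * chunk:(r + 1) * chunk] @ attn_ow[r * chunk:(r + 1) * chunk, :]
    for r in range(parts)
]
all_reduced = sum(partials)               # what dist.all_reduce produces across ranks

assert torch.allclose(full, all_reduced, atol=1e-5)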
class BloomSelfAttention(DeepSpeedSelfAttention): class BloomSelfAttention(DeepSpeedSelfAttention):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(BloomSelfAttention, self).__init__(*args, **kwargs) super(BloomSelfAttention, self).__init__(*args, **kwargs)
self.softmax_func = SoftmaxOp(self.config) self.softmax_func = SoftmaxOp(self.config)
...@@ -156,10 +185,7 @@ class BloomSelfAttention(DeepSpeedSelfAttention): ...@@ -156,10 +185,7 @@ class BloomSelfAttention(DeepSpeedSelfAttention):
(self.hidden_size_per_partition,) (self.hidden_size_per_partition,)
return x.view(*new_x_layer_shape).contiguous() return x.view(*new_x_layer_shape).contiguous()
def _split_tensor_along_last_dim(self, def _split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=True):
tensor,
num_partitions,
contiguous_split_chunks=True):
"""Split a tensor along its last dimension. """Split a tensor along its last dimension.
Args: Args:
...@@ -196,64 +222,43 @@ class BloomSelfAttention(DeepSpeedSelfAttention): ...@@ -196,64 +222,43 @@ class BloomSelfAttention(DeepSpeedSelfAttention):
mixed_x_layer = qkv_out mixed_x_layer = qkv_out
alibi = alibi.to(get_accelerator().current_device_name()) alibi = alibi.to(get_accelerator().current_device_name())
head_dim = self.hidden_size_per_partition // self.num_attention_heads_per_partition head_dim = self.hidden_size_per_partition // self.num_attention_heads_per_partition
new_tensor_shape = mixed_x_layer.size()[:-1] + ( new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, 3 * head_dim)
self.num_attention_heads_per_partition,
3 * head_dim)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
query_layer, key_layer, value_layer = self._split_tensor_along_last_dim(mixed_x_layer, 3) query_layer, key_layer, value_layer = self._split_tensor_along_last_dim(mixed_x_layer, 3)
# [batch_size, head_dim, q_length, k_length] # [batch_size, head_dim, q_length, k_length]
output_size = (query_layer.size(0), output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
query_layer.size(2),
query_layer.size(1),
key_layer.size(1))
# [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim] # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
query_layer = query_layer.transpose(1, query_layer = query_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[2], -1)
2).reshape(output_size[0] * output_size[1],
output_size[2],
-1)
# [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim] # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
key_layer = key_layer.transpose(1, key_layer = key_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3],
2).reshape(output_size[0] * output_size[1], -1).transpose(-1, -2)
output_size[3], value_layer = value_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3], -1)
-1).transpose(-1,
-2)
value_layer = value_layer.transpose(1,
2).reshape(output_size[0] * output_size[1],
output_size[3],
-1)
if layer_past is not None: if layer_past is not None:
past_key, past_value = layer_past past_key, past_value = layer_past
# concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1) key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
value_layer),
dim=-2)
presents = (key_layer, value_layer) presents = (key_layer, value_layer)
# Raw attention scores. [batch_size * num_heads, q_length, k_length] # Raw attention scores. [batch_size * num_heads, q_length, k_length]
matmul_result = torch.matmul(query_layer, key_layer) matmul_result = torch.matmul(query_layer, key_layer)
# change view to [batch_size, num_heads, q_length, k_length] # change view to [batch_size, num_heads, q_length, k_length]
attention_scores = matmul_result.view(output_size[0], attention_scores = matmul_result.view(output_size[0], output_size[1], output_size[2], -1)
output_size[1],
output_size[2], offset = dist.get_rank() * self.num_attention_heads_per_partition if dist.is_initialized() else 0
-1) attention_probs = self.softmax_func(attn_scores=attention_scores,
attn_mask=((1 - input_mask).half() * minus_inf),
offset = dist.get_rank( alibi=alibi,
) * self.num_attention_heads_per_partition if dist.is_initialized() else 0 triangular=(self.config.triangular_masking
attention_probs = self.softmax_func( and (attention_scores.shape[-2] > 1)),
attn_scores=attention_scores, recompute=False,
attn_mask=((1 - input_mask).half() * minus_inf), local_attention=False,
alibi=alibi, window_size=1,
triangular=(self.config.triangular_masking async_op=False,
and (attention_scores.shape[-2] > 1)), layer_scale=1 / (self.norm_factor * self.norm_factor),
recompute=False, head_offset=offset)
local_attention=False,
window_size=1,
async_op=False,
layer_scale=1 / (self.norm_factor * self.norm_factor),
head_offset=offset)
# change view [batch_size x num_heads, q_length, k_length] # change view [batch_size x num_heads, q_length, k_length]
attention_probs_reshaped = attention_probs.view(*matmul_result.shape) attention_probs_reshaped = attention_probs.view(*matmul_result.shape)
...@@ -263,10 +268,8 @@ class BloomSelfAttention(DeepSpeedSelfAttention): ...@@ -263,10 +268,8 @@ class BloomSelfAttention(DeepSpeedSelfAttention):
# change view [batch_size, num_heads, q_length, head_dim] # change view [batch_size, num_heads, q_length, head_dim]
context_layer = context_layer.view( context_layer = context_layer.view(
context_layer.size(0) // self.num_attention_heads_per_partition, context_layer.size(0) // self.num_attention_heads_per_partition, self.num_attention_heads_per_partition,
self.num_attention_heads_per_partition, context_layer.size(1), context_layer.shape[-1])
context_layer.size(1),
context_layer.shape[-1])
context_layer = self._transpose_for_context(context_layer) context_layer = self._transpose_for_context(context_layer)
key_layer = presents[0] key_layer = presents[0]
......
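A shape walk-through of the BLOOM fused-QKV path above may help: the fused projection is viewed as [batch, q_len, heads, 3 * head_dim], split along the last dimension, and the head axis is folded into the batch axis so one batched matmul covers every head. A hedged sketch with hypothetical sizes (torch.chunk stands in for _split_tensor_along_last_dim, which uses torch.split):

import torch

batch, q_len, heads, head_dim = 2, 5, 4, 8
mixed = torch.randn(batch, q_len, heads * 3 * head_dim)   # fused QKV output

# [b, q, heads, 3 * head_dim] -> three [b, q, heads, head_dim] tensors.
mixed = mixed.view(batch, q_len, heads, 3 * head_dim)
query, key, value = mixed.chunk(3, dim=-1)

# Collapse batch and heads so a single bmm covers every head.
query = query.transpose(1, 2).reshape(batch * heads, q_len, head_dim)
key = key.transpose(1, 2).reshape(batch * heads, q_len, head_dim).transpose(-1, -2)
value = value.transpose(1, 2).reshape(batch * heads, q_len, head_dim)

scores = torch.matmul(query, key)          # raw scores, [b * heads, q_len, k_len]
assert scores.shape == (batch * heads, q_len, q_len)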
''' # Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import math import math
import torch import torch
...@@ -11,46 +12,41 @@ from .op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp ...@@ -11,46 +12,41 @@ from .op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp
class DeepSpeedMLP(nn.Module): class DeepSpeedMLP(nn.Module):
def __init__(self,
config, def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False):
mp_group=None,
q_scales=None,
q_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super(DeepSpeedMLP, self).__init__() super(DeepSpeedMLP, self).__init__()
self.config = config self.config = config
data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
data_type_fp = torch.half if config.fp16 else torch.float data_type_fp = torch.half if config.fp16 else torch.float
device = get_accelerator().current_device_name() device = get_accelerator().current_device_name()
self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, if self.config.set_empty_params:
dtype=data_type_fp, self.attn_nw = None
device=device), self.attn_nb = None
requires_grad=False) self.inter_w = None
self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, self.inter_b = None
dtype=data_type_fp, self.output_w = None
device=device), self.output_b = None
requires_grad=False) else:
intm_size_per_partition = self.config.intermediate_size // self.config.mp_size self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, requires_grad=False)
intm_size_per_partition, self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
dtype=data_type, requires_grad=False)
device=device), intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
requires_grad=False) self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, intm_size_per_partition,
dtype=data_type_fp, dtype=data_type,
device=device), device=device),
requires_grad=False) requires_grad=False)
self.output_w = nn.Parameter(torch.empty(intm_size_per_partition, self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device),
self.config.hidden_size, requires_grad=False)
dtype=data_type, self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
device=device), self.config.hidden_size,
requires_grad=False) dtype=data_type,
self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, device=device),
dtype=data_type_fp, requires_grad=False)
device=device), self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False) requires_grad=False)
# used for quantization # used for quantization
self.q_scales = q_scales self.q_scales = q_scales
...@@ -79,16 +75,13 @@ class DeepSpeedMLP(nn.Module): ...@@ -79,16 +75,13 @@ class DeepSpeedMLP(nn.Module):
bias=self.inter_b, bias=self.inter_b,
gamma=self.attn_nw, gamma=self.attn_nw,
beta=self.attn_nb) beta=self.attn_nb)
residual = self.residual_add_func( residual = self.residual_add_func(hidden_state=output,
hidden_state=output, residual=residual,
residual=residual, attention_output=input,
attention_output=input, attention_bias=bias if bias is not None else self.output_b,
attention_bias=bias if bias is not None else self.output_b, final_bias=self.output_b,
final_bias=self.output_b, add_bias=bias is not None,
add_bias=bias is not None, residual_add=residual_add)
residual_add=residual_add)
if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1: if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
dist.all_reduce(residual, group=self.mp_group) dist.all_reduce(residual, group=self.mp_group)
return residual return residual
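The reason intm_size_per_partition = intermediate_size // mp_size works in the MLP above is the usual column-parallel / row-parallel split: the first GEMM is split by columns, the GeLU is elementwise, the second GEMM is split by rows, and the trailing all_reduce sums the partial outputs. A single-process sketch of that identity with hypothetical sizes (biases omitted for brevity):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tokens, hidden, inter, mp_size = 4, 8, 16, 2
x = torch.randn(tokens, hidden)
inter_w = torch.randn(hidden, inter)       # full first GEMM weight
output_w = torch.randn(inter, hidden)      # full second GEMM weight

full = F.gelu(x @ inter_w) @ output_w      # single-GPU reference

chunk = inter // mp_size                   # intm_size_per_partition
partials = [
    F.gelu(x @ inter_w[:, r * chunk:(r + 1) * chunk])
    @ output_w[r * chunk:(r + 1) * chunk, :]
    for r in range(mp_size)
]
assert torch.allclose(full, sum(partials), atol=1e-5)   # == dist.all_reduce result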
''' # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import json import json
import math import math
import torch import torch
...@@ -43,6 +45,7 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig): ...@@ -43,6 +45,7 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation. scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation.
return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
""" """
def __init__(self, def __init__(self,
hidden_size=-1, hidden_size=-1,
intermediate_size=-1, intermediate_size=-1,
...@@ -72,23 +75,10 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig): ...@@ -72,23 +75,10 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
mlp_type='standard', mlp_type='standard',
scale_attn_by_inverse_layer_idx=False): scale_attn_by_inverse_layer_idx=False):
super(DeepSpeedMoEInferenceConfig, super(DeepSpeedMoEInferenceConfig,
self).__init__( self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
hidden_size, num_hidden_layers, layer_norm_eps, local_rank, mp_size, fp16, q_int8, pre_layer_norm,
(intermediate_size if intermediate_size > 0 else 4 * hidden_size), stochastic_mode, scale_attention, triangular_masking, local_attention, window_size,
heads, return_tuple)
num_hidden_layers,
layer_norm_eps,
local_rank,
mp_size,
fp16,
q_int8,
pre_layer_norm,
stochastic_mode,
scale_attention,
triangular_masking,
local_attention,
window_size,
return_tuple)
self.moe_experts = moe_experts self.moe_experts = moe_experts
self.k = k self.k = k
self.capacity_factor = capacity_factor self.capacity_factor = capacity_factor
...@@ -116,44 +106,21 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig): ...@@ -116,44 +106,21 @@ class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig):
class DeepSpeedMLPFunction(Function): class DeepSpeedMLPFunction(Function):
@staticmethod @staticmethod
def forward(ctx, def forward(ctx, input, inter_w, inter_b, config, output_b, output_w, q_scales, q_groups, merge_count, mp_group,
input,
inter_w,
inter_b,
config,
output_b,
output_w,
q_scales,
q_groups,
merge_count,
mp_group,
async_op): async_op):
if config.q_int8: if config.q_int8:
intermediate = inference_cuda_module.fused_gemm_gelu_int8( intermediate = inference_cuda_module.fused_gemm_gelu_int8(input, inter_w, inter_b, config.epsilon,
input, q_scales[2], (q_groups * (2**merge_count)),
inter_w, config.pre_layer_norm)
inter_b, output = inference_cuda_module.vector_matmul_int8(intermediate, output_w, q_scales[3], q_groups,
config.epsilon,
q_scales[2],
(q_groups * (2**merge_count)),
config.pre_layer_norm)
output = inference_cuda_module.vector_matmul_int8(intermediate,
output_w,
q_scales[3],
q_groups,
(merge_count)) (merge_count))
else: else:
mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \ mlp_gemm_func = inference_cuda_module.fused_gemm_gelu_fp16 if config.fp16 else \
inference_cuda_module.fused_gemm_gelu_fp32 inference_cuda_module.fused_gemm_gelu_fp32
output = mlp_gemm_func(input, output = mlp_gemm_func(input, inter_w, inter_b, output_w, config.epsilon, config.pre_layer_norm, async_op)
inter_w,
inter_b,
output_w,
config.epsilon,
config.pre_layer_norm,
async_op)
if mp_group is not None and dist.get_world_size(group=mp_group) > 1: if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
dist.all_reduce(output, group=mp_group, async_op=async_op) dist.all_reduce(output, group=mp_group, async_op=async_op)
...@@ -166,24 +133,17 @@ class DeepSpeedMLPFunction(Function): ...@@ -166,24 +133,17 @@ class DeepSpeedMLPFunction(Function):
class DeepSpeedMoEMLP(nn.Module): class DeepSpeedMoEMLP(nn.Module):
def __init__(self,
config, def __init__(self, config, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False, mp_group=None):
q_scales=None,
q_groups=1,
merge_count=1,
mlp_extra_grouping=False,
mp_group=None):
super(DeepSpeedMoEMLP, self).__init__() super(DeepSpeedMoEMLP, self).__init__()
self.config = config self.config = config
self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))
interm_size = self.config.intermediate_size // ( interm_size = self.config.intermediate_size // (1 if mp_group is None else dist.get_world_size(group=mp_group))
1 if mp_group is None else dist.get_world_size(group=mp_group))
self.inter_w = nn.Parameter(torch.Tensor(self.config.hidden_size, interm_size)) self.inter_w = nn.Parameter(torch.Tensor(self.config.hidden_size, interm_size))
self.inter_b = nn.Parameter(torch.Tensor(interm_size)) self.inter_b = nn.Parameter(torch.Tensor(interm_size))
self.output_w = nn.Parameter(torch.Tensor((interm_size), self.output_w = nn.Parameter(torch.Tensor((interm_size), self.config.hidden_size))
self.config.hidden_size))
self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
# used for quantization # used for quantization
...@@ -193,17 +153,8 @@ class DeepSpeedMoEMLP(nn.Module): ...@@ -193,17 +153,8 @@ class DeepSpeedMoEMLP(nn.Module):
self.mp_group = mp_group self.mp_group = mp_group
def forward(self, input, async_op=False): def forward(self, input, async_op=False):
return DeepSpeedMLPFunction.apply(input, return DeepSpeedMLPFunction.apply(input, self.inter_w, self.inter_b, self.config, self.output_b, self.output_w,
self.inter_w, self.q_scales, self.q_groups, self.merge_count, self.mp_group, async_op)
self.inter_b,
self.config,
self.output_b,
self.output_w,
self.q_scales,
self.q_groups,
self.merge_count,
self.mp_group,
async_op)
class DeepSpeedMoEInference(nn.Module): class DeepSpeedMoEInference(nn.Module):
...@@ -251,11 +202,7 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -251,11 +202,7 @@ class DeepSpeedMoEInference(nn.Module):
self.config.specialized_mode = specialized_mode self.config.specialized_mode = specialized_mode
DeepSpeedMoEInference.layer_id += 1 DeepSpeedMoEInference.layer_id += 1
self.attention = DeepSpeedSelfAttention(self.config, self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
mp_group,
quantize_scales,
quantize_groups,
merge_count)
self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))
...@@ -263,11 +210,7 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -263,11 +210,7 @@ class DeepSpeedMoEInference(nn.Module):
self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
if config.mlp_type == 'residual': if config.mlp_type == 'residual':
self.res_mlp = DeepSpeedMoEMLP(config, self.res_mlp = DeepSpeedMoEMLP(config, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping,
mp_group) mp_group)
self.res_coef = nn.Parameter(torch.Tensor(self.config.hidden_size, 2)) self.res_coef = nn.Parameter(torch.Tensor(self.config.hidden_size, 2))
self.coef_func = inference_cuda_module.softmax_fp16 if self.config.fp16 or self.config.q_int8 else \ self.coef_func = inference_cuda_module.softmax_fp16 if self.config.fp16 or self.config.q_int8 else \
...@@ -277,21 +220,12 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -277,21 +220,12 @@ class DeepSpeedMoEInference(nn.Module):
config.mp_size = 1 config.mp_size = 1
self.mlp = nn.ModuleList( self.mlp = nn.ModuleList(
DeepSpeedMoEMLP(config, DeepSpeedMoEMLP(config, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping, expert_mp_group)
quantize_scales, for i in range(self.config.moe_experts))
quantize_groups,
merge_count, self.moe_gate = TopKGate(self.config.hidden_size, self.config.global_experts, self.config.k,
mlp_extra_grouping, self.config.capacity_factor, self.config.eval_capacity_factor,
expert_mp_group) for i in range(self.config.moe_experts)) self.config.min_capacity, self.config.noisy_gate_policy, self.config.drop_tokens,
self.moe_gate = TopKGate(self.config.hidden_size,
self.config.global_experts,
self.config.k,
self.config.capacity_factor,
self.config.eval_capacity_factor,
self.config.min_capacity,
self.config.noisy_gate_policy,
self.config.drop_tokens,
self.config.use_rts) self.config.use_rts)
self.ep_group = ep_group self.ep_group = ep_group
...@@ -315,19 +249,14 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -315,19 +249,14 @@ class DeepSpeedMoEInference(nn.Module):
_, combined_weights, dispatch_mask, _ = self.moe_gate( _, combined_weights, dispatch_mask, _ = self.moe_gate(
attention_output.view(-1, self.config.hidden_size), attention_output.view(-1, self.config.hidden_size),
None, None,
) )
dispatched_attention = self.einsum_sec_sm_ecm( dispatched_attention = self.einsum_sec_sm_ecm(dispatch_mask.type_as(attention_output),
dispatch_mask.type_as(attention_output), attention_output.view(-1, self.config.hidden_size))
attention_output.view(-1,
self.config.hidden_size))
return dispatched_attention, combined_weights return dispatched_attention, combined_weights
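The gating step above dispatches tokens to experts with an 'sec,sm->ecm' einsum: s tokens, e experts, c capacity slots, m hidden dim. A small sketch of that contraction with a hand-built one-hot mask (hypothetical sizes; the real dispatch_mask comes from TopKGate and is sparse):

import torch

s, e, c, m = 6, 2, 3, 4
tokens = torch.randn(s, m)

# One-hot dispatch mask: token i goes to expert (i % e), capacity slot (i // e).
dispatch_mask = torch.zeros(s, e, c)
for i in range(s):
    dispatch_mask[i, i % e, i // e] = 1.0

# 'sec,sm->ecm': gather each routed token into its (expert, slot) bucket.
dispatched = torch.einsum('sec,sm->ecm', dispatch_mask, tokens)
assert dispatched.shape == (e, c, m)
assert torch.equal(dispatched[0, 0], tokens[0])   # token 0 -> expert 0, slot 0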
def expert_exec(self, dispatched_input): def expert_exec(self, dispatched_input):
dispatched_input = dispatched_input.reshape( dispatched_input = dispatched_input.reshape(self.config.global_experts // self.config.moe_experts,
self.config.global_experts // self.config.moe_experts, self.config.moe_experts, -1, self.config.hidden_size)
self.config.moe_experts,
-1,
self.config.hidden_size)
chunks = dispatched_input.chunk(self.config.moe_experts, dim=1) chunks = dispatched_input.chunk(self.config.moe_experts, dim=1)
expert_outputs = torch.empty(( expert_outputs = torch.empty((
...@@ -337,29 +266,22 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -337,29 +266,22 @@ class DeepSpeedMoEInference(nn.Module):
dtype=dispatched_input.dtype, dtype=dispatched_input.dtype,
device=dispatched_input.device) device=dispatched_input.device)
for chunk, expert in zip(chunks, range(len(self.mlp))): for chunk, expert in zip(chunks, range(len(self.mlp))):
expert_outputs[expert] = self.mlp[expert](chunk.view( expert_outputs[expert] = self.mlp[expert](chunk.view(-1, dispatched_input.shape[-2],
-1, dispatched_input.shape[-1]))
dispatched_input.shape[-2],
dispatched_input.shape[-1]))
return expert_outputs return expert_outputs
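expert_exec above regroups the dispatched buckets by expert-parallel rank and local expert, then hands one chunk to each local expert MLP. A shape-only sketch with hypothetical sizes and an identity stand-in for the expert call:

import torch

global_experts, local_experts, capacity, hidden = 4, 2, 3, 8
dispatched = torch.randn(global_experts, capacity, hidden)

# [ep_world_size, local_experts, capacity, hidden]; ep_world_size = global // local.
dispatched = dispatched.reshape(global_experts // local_experts, local_experts, -1, hidden)

chunks = dispatched.chunk(local_experts, dim=1)        # one chunk per local expert
outputs = torch.empty(local_experts, global_experts // local_experts, capacity, hidden)
for expert_idx, chunk in enumerate(chunks):
    # Identity stand-in for self.mlp[expert_idx](...); only the shapes matter here.
    outputs[expert_idx] = chunk.view(-1, dispatched.shape[-2], dispatched.shape[-1])
assert outputs.shape == (local_experts, global_experts // local_experts, capacity, hidden)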
def _alltoall(self, dispatched_attention): def _alltoall(self, dispatched_attention):
if dist.get_world_size(group=self.ep_group) > 1: if dist.get_world_size(group=self.ep_group) > 1:
dispatched_input = torch.empty_like(dispatched_attention) dispatched_input = torch.empty_like(dispatched_attention)
dist.all_to_all_single(dispatched_input, dist.all_to_all_single(dispatched_input, dispatched_attention, group=self.ep_group)
dispatched_attention,
group=self.ep_group)
return dispatched_input return dispatched_input
else: else:
return dispatched_attention return dispatched_attention
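_alltoall exchanges expert buckets between ranks with dist.all_to_all_single, which splits the input evenly along dim 0 and delivers one slice from every peer. A hedged sketch of how that call is typically wrapped, assuming the expert-parallel process group is created elsewhere:

import torch
import torch.distributed as dist

def exchange_expert_buckets(dispatched: torch.Tensor, ep_group) -> torch.Tensor:
    """Swap expert buckets across the expert-parallel group.

    dispatched is laid out [num_global_experts, capacity, hidden]; after the
    exchange each rank holds the buckets destined for its local experts.
    Falls back to a no-op when expert parallelism is not active.
    """
    if ep_group is None or dist.get_world_size(group=ep_group) == 1:
        return dispatched
    received = torch.empty_like(dispatched)
    dist.all_to_all_single(received, dispatched, group=ep_group)
    return received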
def scale_expert_output(self, attention_output, expert_output, combined_weights): def scale_expert_output(self, attention_output, expert_output, combined_weights):
combined_output = torch.matmul( combined_output = torch.matmul(
combined_weights.type_as(attention_output).reshape( combined_weights.type_as(attention_output).reshape(combined_weights.shape[0], -1),
combined_weights.shape[0], expert_output.reshape(-1, expert_output.shape[-1]))
-1),
expert_output.reshape(-1,
expert_output.shape[-1]))
return combined_output.reshape(attention_output.shape) return combined_output.reshape(attention_output.shape)
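scale_expert_output is the inverse of the dispatch einsum, 'sec,ecm->sm', written as a single 2-D matmul over the flattened (expert, slot) axis. A quick equivalence check with hypothetical sizes:

import torch

s, e, c, m = 6, 2, 3, 4
combined_weights = torch.rand(s, e, c)     # gate weights per (token, expert, slot)
expert_output = torch.randn(e, c, m)       # outputs returned from the experts

via_matmul = combined_weights.reshape(s, -1) @ expert_output.reshape(-1, m)
via_einsum = torch.einsum('sec,ecm->sm', combined_weights, expert_output)
assert torch.allclose(via_matmul, via_einsum, atol=1e-5)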
def forward(self, def forward(self,
...@@ -385,16 +307,9 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -385,16 +307,9 @@ class DeepSpeedMoEInference(nn.Module):
input = input.half() input = input.half()
with torch.no_grad(): with torch.no_grad():
attention_output = self.attention(input, attention_output = self.attention(input, input_mask, head_mask, layer_past, get_present,
input_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
head_mask, self.norm_w, self.norm_b)
layer_past,
get_present,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
self.norm_w,
self.norm_b)
if get_present: if get_present:
attention_output, p_key, p_value = attention_output[0:3] attention_output, p_key, p_value = attention_output[0:3]
...@@ -405,10 +320,7 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -405,10 +320,7 @@ class DeepSpeedMoEInference(nn.Module):
attention_output = attention_output[0] attention_output = attention_output[0]
residual_add = attention_output + self.attention.attn_ob residual_add = attention_output + self.attention.attn_ob
attention_output = self.ds_layernorm(residual_add, attention_output = self.ds_layernorm(residual_add, self.attn_nw, self.attn_nb, self.config.epsilon)
self.attn_nw,
self.attn_nb,
self.config.epsilon)
if self.config.mlp_type == 'residual': if self.config.mlp_type == 'residual':
res_mlp_out = self.res_mlp(attention_output, async_op=True) res_mlp_out = self.res_mlp(attention_output, async_op=True)
...@@ -416,13 +328,10 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -416,13 +328,10 @@ class DeepSpeedMoEInference(nn.Module):
if self.expert_mp_group is not None: if self.expert_mp_group is not None:
tensor_list = [ tensor_list = [
torch.empty_like(attention_output) torch.empty_like(attention_output) for _ in range(dist.get_world_size(group=self.expert_mp_group))
for _ in range(dist.get_world_size(group=self.expert_mp_group))
] ]
tensor_list[dist.get_rank(group=self.expert_mp_group)] = attention_output tensor_list[dist.get_rank(group=self.expert_mp_group)] = attention_output
dist.all_gather(tensor_list, dist.all_gather(tensor_list, attention_output, group=self.expert_mp_group)
attention_output,
group=self.expert_mp_group)
attention_output = torch.cat(tensor_list).contiguous() attention_output = torch.cat(tensor_list).contiguous()
############## MoE Gating + Experts ############### ############## MoE Gating + Experts ###############
...@@ -430,14 +339,11 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -430,14 +339,11 @@ class DeepSpeedMoEInference(nn.Module):
dispatched_input = self._alltoall(dispatched_attention) dispatched_input = self._alltoall(dispatched_attention)
expert_outputs = self.expert_exec(dispatched_input) expert_outputs = self.expert_exec(dispatched_input)
expert_output = self._alltoall(expert_outputs) expert_output = self._alltoall(expert_outputs)
output = self.scale_expert_output(attention_output, output = self.scale_expert_output(attention_output, expert_output, combined_weights)
expert_output,
combined_weights)
################################################ ################################################
if self.expert_mp_group is not None: if self.expert_mp_group is not None:
output = output.split(output.shape[0] // output = output.split(output.shape[0] // dist.get_world_size(group=self.expert_mp_group),
dist.get_world_size(group=self.expert_mp_group),
dim=0)[dist.get_rank(group=self.expert_mp_group)] dim=0)[dist.get_rank(group=self.expert_mp_group)]
if self.config.mlp_type == 'residual': if self.config.mlp_type == 'residual':
...@@ -446,10 +352,7 @@ class DeepSpeedMoEInference(nn.Module): ...@@ -446,10 +352,7 @@ class DeepSpeedMoEInference(nn.Module):
output = self.bias_residual_func(output, residual_add, torch.empty(1)) output = self.bias_residual_func(output, residual_add, torch.empty(1))
if not self.config.pre_layer_norm: if not self.config.pre_layer_norm:
output = self.ds_layernorm(output, output = self.ds_layernorm(output, self.norm_w, self.norm_b, self.config.epsilon)
self.norm_w,
self.norm_b,
self.config.epsilon)
if input_type != output.dtype: if input_type != output.dtype:
output = output.to(input_type) output = output.to(input_type)
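One subtlety in the forward path above: when expert_mp_group is set, activations are all-gathered across that group before gating and split back afterwards, so every rank routes the full batch but keeps only its own shard of the combined output. A tiny single-process sketch of that gather/split symmetry (hypothetical sizes; dist.all_gather is simulated with a list):

import torch

world, tokens_per_rank, hidden = 2, 3, 4
shards = [torch.randn(tokens_per_rank, hidden) for _ in range(world)]

# dist.all_gather + torch.cat: every rank sees the full batch before gating.
gathered = torch.cat(shards).contiguous()            # [world * tokens_per_rank, hidden]

# ... gating / experts / combine operate on `gathered` ...
output = gathered                                    # identity stand-in for the MoE output

# output.split(output.shape[0] // world)[rank]: each rank keeps only its shard.
for rank in range(world):
    shard = output.split(output.shape[0] // world, dim=0)[rank]
    assert torch.equal(shard, shards[rank])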
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .linear import LinearOp from .linear import LinearOp
from .vector_matmul import VectorMatMulOp from .vector_matmul import VectorMatMulOp
......
'''Copyright The Microsoft DeepSpeed Team''' # Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch import torch
from ..config import DeepSpeedInferenceConfig from ..config import DeepSpeedInferenceConfig
......