Unverified Commit 7cfed5f0 authored by Frank Lee, committed by GitHub

[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
parent d7f8db8e
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import rearrange
from ..base_extension import BaseExtension
from ..utils import print_rank_0
HAS_NPU_TRIANGLE_ATTENTION = False
try:
from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax
HAS_NPU_TRIANGLE_ATTENTION = True
except ImportError:
pass
if HAS_NPU_TRIANGLE_ATTENTION:
def npu_triangle_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q=None,
seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
dropout_p: float = 0.0,
scale: float = 1.0,
causal=None,
padded=None,
block_size=512,
):
"""
The triangle attention reduces the attention calculation of the mask
part by dividing the q, k, and v matrices into blocks
Arguments:
block_size: The size of the inverted triangle block, the default is 512,
the smaller the block_size, the more calculations will be reduced,
but the number of small operators will be increased
masked_softmax_func: mask function to be applied.
dropout_func: dropout function to be applied.
"""
def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
# [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
cur_sim = torch.matmul(q_layer, k_layer)
attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp)
# attention dropout
if dropout_p > 0:
attention_probs = torch.nn.functional.dropout(
attention_probs, p=dropout_p, training=attention_probs.requires_grad
)
# [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
context_layer_tmp = torch.matmul(attention_probs, v_layer)
return context_layer_tmp
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
origin_attn_mask = origin_attn_mask.to(torch.bool)
# input shape: [b, hn, sq, hd]
bsz, head_num, sequence_len, head_dim = k.shape
sparse_groups = sequence_len // block_size
# Determine whether the sequence length can be evenly divided by the block size
divisible_flag = sequence_len == block_size * sparse_groups
k = k.transpose(2, 3).contiguous()
if divisible_flag:
q_tmp_layers = torch.chunk(q, sparse_groups, 2)
k_tmp_layers = torch.chunk(k, sparse_groups, 3)
v_tmp_layers = torch.chunk(v, sparse_groups, 2)
else:
seq_tmp = block_size * sparse_groups
q_last = q[:, :, seq_tmp:, :].contiguous()
mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
context_list_tmp, k_tmp, v_tmp = [], (), ()
for i in range(sparse_groups):
# compute the slice boundaries of q, k and v for this iteration
q_begin, q_end = i * block_size, (i + 1) * block_size
kv_begin, kv_end = 0, (i + 1) * block_size
q_tmp = q_tmp_layers[i]
# slice k and v
if i == 0:
k_tmp = k_tmp_layers[i].contiguous()
v_tmp = v_tmp_layers[i].contiguous()
else:
k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()
mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
context_list_tmp.append(context_layer_tmp)
if not divisible_flag:
# handle the tail block when the sequence length is not divisible by block_size
context_layer_tmp = compute_attn(q_last, k, v, mask_last)
context_list_tmp.append(context_layer_tmp)
context_layer = torch.cat(context_list_tmp, 2)
new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
# =========================
# Context layer. [b, sq, hp]
# =========================
return context_layer
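# Illustrative sketch only (not part of the extension): the same blockwise-causal pattern
# expressed with plain PyTorch, assuming the sequence length divides evenly by block_size.
# `blockwise_causal_attention` is a hypothetical helper added purely to show how each
# query block attends to its prefix of key/value blocks.
def blockwise_causal_attention(q, k, v, block_size):
    # q, k, v: [b, heads, seq, dim]
    groups = q.shape[2] // block_size
    outputs = []
    for i in range(groups):
        q_blk = q[:, :, i * block_size : (i + 1) * block_size]
        k_blk = k[:, :, : (i + 1) * block_size]
        v_blk = v[:, :, : (i + 1) * block_size]
        # [b, heads, block, prefix] attention scores for this block
        scores = q_blk @ k_blk.transpose(-1, -2) / q.shape[-1] ** 0.5
        # mask future positions inside the diagonal block; earlier blocks stay fully visible
        keep = torch.ones(block_size, (i + 1) * block_size, dtype=torch.bool, device=q.device).tril(i * block_size)
        scores = scores.masked_fill(~keep, float("-inf"))
        outputs.append(torch.softmax(scores, dim=-1) @ v_blk)
    return torch.cat(outputs, dim=2)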
class NpuTriangleAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self) -> bool:
return False
def build(self):
pass
def is_available(self):
if not HAS_NPU_TRIANGLE_ATTENTION:
print_rank_0(
"ImportError: please install the latest torch_npu that provides the 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' APIs."
)
return HAS_NPU_TRIANGLE_ATTENTION
def load(self):
return npu_triangle_attention
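# A minimal usage sketch for the extension above (hypothetical tensors; the real kernel
# path requires torch_npu to be installed):
#
#   ext = NpuTriangleAttnExtension()
#   if ext.is_available():
#       triangle_attn = ext.load()
#       # q, k, v: [batch, seq_len, num_heads, head_dim]
#       out = triangle_attn(q, k, v, origin_attn_mask=mask, dropout_p=0.0, block_size=512)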
import enum
from dataclasses import dataclass
from typing import Iterable, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.accelerator import get_accelerator
class Unpad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
@dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
seqlens = torch.tensor([tgt_len] * batch_size, dtype=torch.long, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
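# Worked sketch (hand-computed values, CPU device assumed) showing how SeqLenInfo, Unpad
# and Repad cooperate for a padded batch; not part of the module:
#
#   attn_mask = torch.tensor([[1, 1, 1, 0],
#                             [1, 1, 0, 0]])             # batch of 2, seq_len 4
#   info = SeqLenInfo.materialize(attn_mask=attn_mask, device="cpu")
#   # info.seqlens    -> [3, 2]
#   # info.max_seqlen -> 3
#   # info.cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)
#   # info.indices    -> valid-token positions in the flattened [b*s] layout: [0, 1, 2, 4, 5]
#
#   hidden = torch.randn(2, 4, 8)                         # [b, s, hidden]
#   unpadded = Unpad.apply(hidden, info.indices)          # [ntokens, hidden] == [5, 8]
#   repadded = Repad.apply(unpadded, info.indices, 2, 4)  # [b*s, hidden] with zeros at pad slots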
import warnings
from typing import List
from .extensions import (
CpuAdamArmExtension,
CpuAdamX86Extension,
FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
FusedOptimizerCudaExtension,
LayerNormCudaExtension,
MoeCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
)
from .extensions.base_extension import _Extension
__all__ = [
"KernelLoader",
"CPUAdamLoader",
"LayerNormLoader",
"MoeLoader",
"FusedOptimizerLoader",
"ScaledMaskedSoftmaxLoader",
"ScaledUpperTriangleMaskedSoftmaxLoader",
]
class KernelLoader:
"""
An abstract class that encapsulates the kernel loading process.
Usage:
kernel_loader = KernelLoader()
kernel = kernel_loader.load()
"""
REGISTRY: List[_Extension] = []
@classmethod
def register_extension(cls, extension: _Extension):
"""
This classmethod is an extension point which allows users to register their customized
kernel implementations to the loader.
Args:
extension (_Extension): the extension to be registered.
"""
cls.REGISTRY.append(extension)
def load(self, ext_name: str = None):
"""
Load the kernel according to the current machine.
Args:
ext_name (str): the name of the extension to be loaded. If not specified, the loader
will look for a kernel available on the current machine.
"""
exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]
# look for exts which can be built/loaded on the current machine
if ext_name:
usable_exts = list(filter(lambda ext: ext.name == ext_name, exts))
else:
usable_exts = []
for ext in exts:
if ext.is_hardware_available():
# make sure the machine is compatible during kernel loading
ext.assert_hardware_compatible()
usable_exts.append(ext)
assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
if len(usable_exts) > 1:
# if more than one usable kernel is found, we will try to load the kernel with the highest priority
usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
warnings.warn(
f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}"
)
return usable_exts[0].load()
class CPUAdamLoader(KernelLoader):
REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]
class LayerNormLoader(KernelLoader):
REGISTRY = [LayerNormCudaExtension]
class MoeLoader(KernelLoader):
REGISTRY = [MoeCudaExtension]
class FusedOptimizerLoader(KernelLoader):
REGISTRY = [FusedOptimizerCudaExtension]
class ScaledMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledMaskedSoftmaxCudaExtension]
class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension]
class FlashAttentionLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
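# Illustrative sketch (hypothetical extension class and name): how a user-defined kernel
# can be plugged into the loader machinery via the register_extension() hook, or selected
# explicitly by name through load(ext_name=...):
#
#   class MyFlashAttentionExtension(_Extension):
#       # implement the interface KernelLoader.load relies on:
#       # name, priority, is_hardware_available(), assert_hardware_compatible(), load()
#       ...
#
#   FlashAttentionLoader.register_extension(MyFlashAttentionExtension)
#   kernel = FlashAttentionLoader().load()                                 # highest-priority usable kernel
#   kernel = FlashAttentionLoader().load(ext_name="my_flash_attention")    # or pick one explicitly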
../../op_builder
\ No newline at end of file
......@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
......@@ -28,7 +28,7 @@ def load_fused_optim():
global fused_optim
if fused_optim is None:
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
......
......@@ -11,7 +11,6 @@ from torch import Tensor
from torch.nn.parameter import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.kernel import LayerNorm
from colossalai.legacy.communication import broadcast
from colossalai.legacy.context import ParallelMode, seed
from colossalai.legacy.context.parallel_context import global_context as gpc
......@@ -23,6 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
......
......@@ -8,13 +8,12 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from colossalai.kernel import FusedScaleMaskSoftmax
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.legacy.context import seed
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK
from colossalai.legacy.registry import LAYERS
from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax
@LAYERS.register_module
......
......@@ -96,9 +96,9 @@ def _calc_l2_norm(grads):
global fused_optim
if fused_optim is None:
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
norm = 0.0
if len(grads) > 0:
......
......@@ -11,9 +11,9 @@ MOE_KERNEL = None
def load_moe():
global MOE_KERNEL
from colossalai.kernel.op_builder import MOEBuilder
from colossalai.kernel.kernel_loader import MoeLoader
MOE_KERNEL = MOEBuilder().load()
MOE_KERNEL = MoeLoader().load()
class AllGather(torch.autograd.Function):
......@@ -145,14 +145,8 @@ class AllToAll(torch.autograd.Function):
class HierarchicalAllToAll(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
groups: Tuple[ProcessGroup, ProcessGroup],
src_rank: int
) -> Tensor:
def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor:
"""
Returns:
outputs: Tensor
......@@ -276,8 +270,9 @@ class MoeCombine(torch.autograd.Function):
if tokens_grad.dtype != torch.float32:
tokens_grad = tokens_grad.to(torch.float32)
d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits,
mask, dest_idx)
d_expert, d_logits = MOE_KERNEL.combine_backward(
ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx
)
if d_expert.dtype != ctx.dtype:
d_expert = d_expert.to(ctx.dtype)
......
import enum
import math
from collections import OrderedDict
from typing import Optional
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.accelerator import get_accelerator
from colossalai.kernel.kernel_loader import FlashAttentionLoader
from .base_kernel_loader import BaseKernelLoader
from .extensions.flash_attention import (
AttnMaskType,
CudaFlashAttnExtension,
CudaMemoryEfficentAttnExtension,
NpuSdpaAttnExtension,
NpuTriangleAttnExtension,
Repad,
SeqLenInfo,
Unpad,
)
from .extensions.utils import print_rank_0
class FlashAttentionLoader(BaseKernelLoader):
@dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
seqlens = torch.tensor([tgt_len] * batch_size, dtype=torch.long, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
class Unpad(torch.autograd.Function):
"""
FlashAttention Loader
options: cuda flash attention, cuda memory efficient attention, npu sdpa attention, npu triangle attention
Args:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
sm_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
def __init__(self):
super().__init__(
# extension name must start with the accelerator name. E.g. npu_xxx, cuda_xxx
extension_map=OrderedDict(
cuda_flash_attn=CudaFlashAttnExtension,
cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension,
npu_sdpa_attn=NpuSdpaAttnExtension,
npu_triangle_attn=NpuTriangleAttnExtension,
),
supported_device=["cuda", "npu"],
)
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
def fetch_kernel(self, backend: str = None):
if backend is not None:
if not self._extension_map[backend]().is_available():
raise Exception(f"{backend} is not available for flash attention.")
return self._extension_map[backend]().fetch()
kernel = None
accelerator_name = get_accelerator().name
assert accelerator_name in self._supported_device, f"{accelerator_name} is not supported for flash attention."
for extension_name, extension in self._extension_map.items():
if extension_name.startswith(accelerator_name):
if extension().is_available():
kernel = extension().fetch()
break
if kernel is None:
raise Exception("No extension for flash attention is supported")
return kernel
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
class ColoAttention(torch.nn.Module):
......@@ -84,7 +106,7 @@ class ColoAttention(torch.nn.Module):
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
self.attn = FlashAttentionLoader().fetch_kernel()
self.attn = FlashAttentionLoader().load()
@staticmethod
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
......@@ -120,8 +142,10 @@ class ColoAttention(torch.nn.Module):
if self.attn.__name__ == "flash_attention" and (
query.dtype not in [torch.float16, torch.bfloat16] or bias != None
):
print_rank_0("flash attention is not applicable, switch to memory effcient attention")
self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn")
warnings.warn(
f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
)
self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
causal = attn_mask_type is not None and attn_mask_type.value > 1
......
......@@ -9,7 +9,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn import init
from torch.nn.parameter import Parameter
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
from colossalai.kernel.kernel_loader import LayerNormLoader
try:
from colossalai._C import layer_norm
......@@ -29,7 +29,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
global layer_norm
if layer_norm is None:
layer_norm = LayerNormBuilder().load()
layer_norm = LayerNormLoader().load()
output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.layernorm_op = layer_norm
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
......
# This code is adapted from NVIDIA Megatron, with minor changes.
import enum
import torch
import torch.nn as nn
from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
global scaled_upper_triang_masked_softmax
if scaled_upper_triang_masked_softmax is None:
scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
scale_t = torch.tensor([scale])
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
Fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: Flag to indicate if the input is in fp16 data format.
input_in_bf16: Flag to indicate if the input is in bf16 data format.
attn_mask_type: Attention mask type (pad or causal).
scaled_masked_softmax_fusion: Flag to indicate whether the user wants to use the fused softmax kernel.
mask_func: Mask function to be applied.
softmax_in_fp32: If True, softmax is performed in fp32 precision.
scale: Scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16 or bf16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
def get_batch_per_block(self, sq, sk, b, np):
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
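# Illustrative usage sketch (hypothetical mask function and shapes): building the fused
# softmax as an attention component. Whether the fused CUDA kernel is actually used is
# decided per-call by is_kernel_available(); otherwise the plain torch path runs.
#
#   softmax = FusedScaleMaskSoftmax(
#       input_in_fp16=True,
#       input_in_bf16=False,
#       attn_mask_type=AttnMaskType.padding,
#       scaled_masked_softmax_fusion=True,
#       mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
#       softmax_in_fp32=True,
#       scale=None,
#   )
#   # scores: [b, np, sq, sk] attention logits; mask: boolean mask broadcastable to that shape
#   probs = softmax(scores, mask)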
......@@ -3,7 +3,7 @@ from typing import Optional
import torch
from colossalai.kernel import CPUAdamLoader
from colossalai.kernel.kernel_loader import CPUAdamLoader
from .nvme_optimizer import NVMeOptimizer
......
......@@ -70,9 +70,9 @@ class FusedAdam(torch.optim.Optimizer):
self.adamw_mode = 1 if adamw_mode else 0
self.set_grad_none = set_grad_none
if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
......
......@@ -77,9 +77,9 @@ class FusedLAMB(torch.optim.Optimizer):
)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
# Skip buffer
......
......@@ -72,9 +72,9 @@ class FusedSGD(Optimizer):
self.wd_after_momentum = wd_after_momentum
if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
# Skip buffer
self._dummy_overflow_buf = torch.tensor(
......
......@@ -2,7 +2,7 @@ from typing import Any, Optional
import torch
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.utils import multi_tensor_applier
from .cpu_adam import CPUAdam
......@@ -85,7 +85,7 @@ class HybridAdam(CPUAdam):
nvme_offload_dir,
)
if torch.cuda.is_available():
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
......
......@@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
......
......@@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
from ._utils import (
detach,
......
......@@ -62,7 +62,7 @@ def forward_fn():
def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
from colossalai.kernel import ColoAttention
from colossalai.nn.layer.colo_attention import ColoAttention
def forward(
self: Blip2Attention,
......