Unverified commit 7cfed5f0, authored by Frank Lee and committed by GitHub

[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
parent d7f8db8e
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import rearrange
from ..base_extension import BaseExtension
from ..utils import print_rank_0
HAS_NPU_TRIANGLE_ATTENTION = False
try:
from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax
HAS_NPU_TRIANGLE_ATTENTION = True
except ImportError:
pass
if HAS_NPU_TRIANGLE_ATTENTION:
def npu_triangle_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q=None,
seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
dropout_p: float = 0.0,
scale: float = 1.0,
causal=None,
padded=None,
block_size=512,
):
"""
The triangle attention reduces the attention computation over the masked part
by dividing the q, k, and v matrices into blocks.
Arguments:
block_size: size of the inverted-triangle block (default 512). A smaller
block_size skips more of the masked computation but launches more small operators.
dropout_p: attention dropout probability applied to the softmax output.
"""
def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
# [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
cur_sim = torch.matmul(q_layer, k_layer)
attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp)
# attention dropout
if dropout_p > 0:
attention_probs = torch.nn.functional.dropout(
attention_probs, p=dropout_p, training=attention_probs.requires_grad
)
# [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
context_layer_tmp = torch.matmul(attention_probs, v_layer)
return context_layer_tmp
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
origin_attn_mask = origin_attn_mask.to(torch.bool)
# input shape: [b, hn, sq, hd]
bsz, head_num, sequence_len, head_dim = k.shape
sparse_groups = sequence_len // block_size
# determine whether the sequence length is evenly divisible by the block size
divisible_flag = sequence_len == block_size * sparse_groups
k = k.transpose(2, 3).contiguous()
if divisible_flag:
q_tmp_layers = torch.chunk(q, sparse_groups, 2)
k_tmp_layers = torch.chunk(k, sparse_groups, 3)
v_tmp_layers = torch.chunk(v, sparse_groups, 2)
else:
seq_tmp = block_size * sparse_groups
q_last = q[:, :, seq_tmp:, :].contiguous()
mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
context_list_tmp, k_tmp, v_tmp = [], (), ()
for i in range(sparse_groups):
# compute slice shape of q k v for each loop
q_begin, q_end = i * block_size, (i + 1) * block_size
kv_begin, kv_end = 0, (i + 1) * block_size
q_tmp = q_tmp_layers[i]
# slice k and v
if i == 0:
k_tmp = k_tmp_layers[i].contiguous()
v_tmp = v_tmp_layers[i].contiguous()
else:
k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()
mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
context_list_tmp.append(context_layer_tmp)
if not divisible_flag:
# handle the remaining tokens when the sequence length is not divisible by the block size
context_layer_tmp = compute_attn(q_last, k, v, mask_last)
context_list_tmp.append(context_layer_tmp)
context_layer = torch.cat(context_list_tmp, 2)
new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
# =========================
# Context layer. [b, sq, hp]
# =========================
return context_layer
class NpuTriangleAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self) -> bool:
return False
def build(self):
pass
def is_available(self):
if not HAS_NPU_TRIANGLE_ATTENTION:
print_rank_0(
"ImportError: please install latest torch_npu with 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' api."
)
return HAS_NPU_TRIANGLE_ATTENTION
def load(self):
return npu_triangle_attention
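A minimal usage sketch for this extension follows; it assumes an Ascend NPU environment with a torch_npu build that exposes the two fused ops above, the tensor shapes are purely illustrative, and the mask semantics follow whatever npu_scaled_masked_softmax expects.

import torch
import torch_npu  # noqa: F401, provides the .npu() tensor method

ext = NpuTriangleAttnExtension()
if ext.is_available():
    triangle_attention = ext.load()
    # q/k/v are (batch, seq_len, num_heads, head_dim); the kernel rearranges them internally
    q = torch.randn(2, 1024, 16, 64, dtype=torch.float16).npu()
    k = torch.randn(2, 1024, 16, 64, dtype=torch.float16).npu()
    v = torch.randn(2, 1024, 16, 64, dtype=torch.float16).npu()
    # boolean attention mask, broadcast over heads
    mask = torch.zeros(2, 1, 1024, 1024, dtype=torch.bool).npu()
    out = triangle_attention(q, k, v, origin_attn_mask=mask, block_size=512)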
import enum
from dataclasses import dataclass
from typing import Iterable, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.accelerator import get_accelerator
class Unpad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
@dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
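A small sketch of how these helpers compose (CPU tensors for illustration; the mask convention of 1 = valid token, 0 = padding is inferred from the sum/nonzero logic above):

import torch

attn_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # [batch=2, seq=4]
info = SeqLenInfo.materialize(attn_mask=attn_mask, device="cpu")
# info.seqlens == [3, 2]; info.cu_seqlens == tensor([0, 3, 5], dtype=torch.int32)

hidden = torch.randn(2, 4, 8, 16)           # [batch, seq, heads, head_dim]
packed = Unpad.apply(hidden, info.indices)  # [ntokens=5, heads, head_dim]
restored = Repad.apply(packed, info.indices, 2, 4).view(2, 4, 8, 16)
assert torch.equal(restored[0, :3], hidden[0, :3])  # valid tokens survive the round trip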
import warnings
from typing import List
from .extensions import (
CpuAdamArmExtension,
CpuAdamX86Extension,
FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
FusedOptimizerCudaExtension,
LayerNormCudaExtension,
MoeCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
)
from .extensions.base_extension import _Extension
__all__ = [
"KernelLoader",
"CPUAdamLoader",
"LayerNormLoader",
"MoeLoader",
"FusedOptimizerLoader",
"ScaledMaskedSoftmaxLoader",
"ScaledUpperTriangleMaskedSoftmaxLoader",
]
class KernelLoader:
"""
An abstract class which offers encapsulation to the kernel loading process.
Usage:
kernel_loader = KernelLoader()
kernel = kernel_loader.load()
"""
REGISTRY: List[_Extension] = []
@classmethod
def register_extension(cls, extension: _Extension):
"""
This classmethod is an extension point which allows users to register their customized
kernel implementations to the loader.
Args:
extension (_Extension): the extension to be registered.
"""
cls.REGISTRY.append(extension)
def load(self, ext_name: str = None):
"""
Load the kernel according to the current machine.
Args:
ext_name (str): the name of the extension to be loaded. If not specified, the loader
will look for a kernel available on the current machine.
"""
exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]
# look for exts which can be built/loaded on the current machine
if ext_name:
usable_exts = list(filter(lambda ext: ext.name == ext_name, exts))
else:
usable_exts = []
for ext in exts:
if ext.is_hardware_available():
# make sure the machine is compatible during kernel loading
ext.assert_hardware_compatible()
usable_exts.append(ext)
assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
if len(usable_exts) > 1:
# if more than one usable kernel is found, we will try to load the kernel with the highest priority
usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
warnings.warn(
f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}"
)
return usable_exts[0].load()
class CPUAdamLoader(KernelLoader):
REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]
class LayerNormLoader(KernelLoader):
REGISTRY = [LayerNormCudaExtension]
class MoeLoader(KernelLoader):
REGISTRY = [MoeCudaExtension]
class FusedOptimizerLoader(KernelLoader):
REGISTRY = [FusedOptimizerCudaExtension]
class ScaledMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledMaskedSoftmaxCudaExtension]
class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension]
class FlashAttentionLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
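A short usage sketch for the loader API above; which implementation is returned depends on the hardware detected at runtime, and the extension name passed to load() mirrors the one used later in this PR for the xformers fallback.

from colossalai.kernel.kernel_loader import CPUAdamLoader, FlashAttentionLoader

cpu_adam_kernel = CPUAdamLoader().load()           # x86 or ARM CPU Adam, whichever loads here
flash_attn_kernel = FlashAttentionLoader().load()  # NPU / Dao flash-attn / xformers, by priority

# Force a specific implementation by extension name instead of auto-selection:
xformers_kernel = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")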
../../op_builder
\ No newline at end of file
@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
 from torch.optim import Optimizer

 from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
-from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.kernel.kernel_loader import FusedOptimizerLoader
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
@@ -28,7 +28,7 @@ def load_fused_optim():
     global fused_optim

     if fused_optim is None:
-        fused_optim = FusedOptimBuilder().load()
+        fused_optim = FusedOptimizerLoader().load()

 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
...
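The hunk above shows the call-site migration that recurs throughout this PR: an op_builder builder is swapped one-for-one for a kernel_loader loader, and the returned kernel module is used exactly as before. A sketch of the pattern:

# Before (old API):
#   from colossalai.kernel.op_builder import FusedOptimBuilder
#   fused_optim = FusedOptimBuilder().load()
# After (new API):
from colossalai.kernel.kernel_loader import FusedOptimizerLoader

fused_optim = FusedOptimizerLoader().load()
multi_tensor_adam = fused_optim.multi_tensor_adam  # attributes of the loaded module are unchanged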
@@ -11,7 +11,6 @@ from torch import Tensor
 from torch.nn.parameter import Parameter

 from colossalai.accelerator import get_accelerator
-from colossalai.kernel import LayerNorm
 from colossalai.legacy.communication import broadcast
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.context.parallel_context import global_context as gpc
@@ -23,6 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
+from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm

 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
...
@@ -8,13 +8,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import Parameter

-from colossalai.kernel import FusedScaleMaskSoftmax
-from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
 from colossalai.legacy.context import seed
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK
 from colossalai.legacy.registry import LAYERS
+from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax


 @LAYERS.register_module
...
@@ -96,9 +96,9 @@ def _calc_l2_norm(grads):
     global fused_optim

     if fused_optim is None:
-        from colossalai.kernel.op_builder import FusedOptimBuilder
+        from colossalai.kernel.kernel_loader import FusedOptimizerLoader

-        fused_optim = FusedOptimBuilder().load()
+        fused_optim = FusedOptimizerLoader().load()

     norm = 0.0
     if len(grads) > 0:
...
@@ -11,9 +11,9 @@ MOE_KERNEL = None

 def load_moe():
     global MOE_KERNEL
-    from colossalai.kernel.op_builder import MOEBuilder
+    from colossalai.kernel.kernel_loader import MoeLoader

-    MOE_KERNEL = MOEBuilder().load()
+    MOE_KERNEL = MoeLoader().load()


 class AllGather(torch.autograd.Function):
@@ -145,14 +145,8 @@ class AllToAll(torch.autograd.Function):

 class HierarchicalAllToAll(torch.autograd.Function):
     @staticmethod
-    def forward(
-        ctx: Any,
-        inputs: Tensor,
-        groups: Tuple[ProcessGroup, ProcessGroup],
-        src_rank: int
-    ) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor:
         """
         Returns:
             outputs: Tensor
@@ -276,8 +270,9 @@ class MoeCombine(torch.autograd.Function):
         if tokens_grad.dtype != torch.float32:
             tokens_grad = tokens_grad.to(torch.float32)

-        d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits,
-                                                         mask, dest_idx)
+        d_expert, d_logits = MOE_KERNEL.combine_backward(
+            ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx
+        )
         if d_expert.dtype != ctx.dtype:
             d_expert = d_expert.to(ctx.dtype)
...
+import enum
 import math
-from collections import OrderedDict
-from typing import Optional
+import warnings
+from dataclasses import dataclass
+from typing import Iterable, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from einops import rearrange

 from colossalai.accelerator import get_accelerator
+from colossalai.kernel.kernel_loader import FlashAttentionLoader

-from .base_kernel_loader import BaseKernelLoader
-from .extensions.flash_attention import (
-    AttnMaskType,
-    CudaFlashAttnExtension,
-    CudaMemoryEfficentAttnExtension,
-    NpuSdpaAttnExtension,
-    NpuTriangleAttnExtension,
-    Repad,
-    SeqLenInfo,
-    Unpad,
-)
-from .extensions.utils import print_rank_0
-
-
-class FlashAttentionLoader(BaseKernelLoader):
-    """
-    FlashAttention Loader
-
-    options: cuda flashh attention, cuda memory effcient attention, npu sdpa attention, npu triangle attention
-
-    Args:
-        q: (batch, q_seqlen, nheads, headdim)
-        k: (batch, kv_seqlen, nheads, headdim)
-        v: (batch, kv_seqlen, nheads, headdim)
-        batch_size: int.
-        seq_len: int.
-        dropout_p: float. Dropout probability.
-        sm_scale: float. The scaling of QK^T before applying softmax.
-            Default to 1 / sqrt(headdim).
-        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-
-    Return:
-        attn_out: (batch, q_seqlen, nheads, headdim).
-    """
-
-    def __init__(self):
-        super().__init__(
-            # extension name must start with the accelerator name. E.g. npu_xxx, cuda_xxx
-            extension_map=OrderedDict(
-                cuda_flash_attn=CudaFlashAttnExtension,
-                cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension,
-                npu_sdpa_attn=NpuSdpaAttnExtension,
-                npu_triangle_attn=NpuTriangleAttnExtension,
-            ),
-            supported_device=["cuda", "npu"],
-        )
-
-    def fetch_kernel(self, backend: str = None):
-        if backend is not None:
-            if not self._extension_map[backend]().is_available():
-                raise Exception(f"{backend} is not available for flash attention.")
-            return self._extension_map[backend]().fetch()
-
-        kernel = None
-        accelerator_name = get_accelerator().name
-        assert accelerator_name in self._supported_device, f"{accelerator_name} is not supported for flash attention."
-        for extension_name, extension in self._extension_map.items():
-            if extension_name.startswith(accelerator_name):
-                if extension().is_available():
-                    kernel = extension().fetch()
-                    break
-        if kernel is None:
-            raise Exception("No extension for flash attention is supported")
-        return kernel
+
+@dataclass
+class SeqLenInfo:
+    seqlens: Iterable[int] = None
+    indices: torch.Tensor = None
+    max_seqlen: int = None
+    cu_seqlens: torch.Tensor = None
+
+    @staticmethod
+    def materialize(
+        attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
+    ):
+        if attn_mask is not None:
+            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
+            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
+        else:
+            batch_size, tgt_len = size[0], size[1]
+            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
+            seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
+        max_seqlen = max(seqlens)
+        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
+        return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
+
+
+class AttnMaskType(enum.Enum):
+    padding = 1
+    causal = 2
+    paddedcausal = 3
+
+
+class Unpad(torch.autograd.Function):
+    """
+    Adapted from
+    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+    """
+
+    @staticmethod
+    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
+        ctx.save_for_backward(indices)
+        # [b, s, ...]
+        assert tensor.ndim >= 3
+        ctx.bsz = tensor.shape[0]
+        out = rearrange(tensor, "b s ... -> (b s) ...")
+        ctx.shape = out.shape
+        # [ntokens, ...]
+        return out[indices]
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        # [ntokens, ...]
+        grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
+        grad[indices] = grad_output
+        grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
+        # [b, s, ...]
+        return grad, None
+
+
+class Repad(torch.autograd.Function):
+    """
+    Adapted from
+    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+    """
+
+    @staticmethod
+    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
+        ctx.save_for_backward(indices)
+        # [ntokens, ...]
+        tensor = tensor
+        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
+        # [b*s, ...]
+        out[indices] = tensor
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        # [b*s, ...]
+        grad = grad_output[indices]
+        # [ntokens, ...]
+        return grad, None, None, None


 class ColoAttention(torch.nn.Module):
@@ -84,7 +106,7 @@ class ColoAttention(torch.nn.Module):
         self.scale = 1 / math.sqrt(embed_dim // num_heads)
         self.dropout = dropout

-        self.attn = FlashAttentionLoader().fetch_kernel()
+        self.attn = FlashAttentionLoader().load()

     @staticmethod
     def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
@@ -120,8 +142,10 @@ class ColoAttention(torch.nn.Module):
         if self.attn.__name__ == "flash_attention" and (
             query.dtype not in [torch.float16, torch.bfloat16] or bias != None
         ):
-            print_rank_0("flash attention is not applicable, switch to memory effcient attention")
-            self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn")
+            warnings.warn(
+                f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
+            )
+            self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")
         padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
         causal = attn_mask_type is not None and attn_mask_type.value > 1
...
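For reference, the padded/causal flags computed at the end of this hunk follow from the AttnMaskType encoding defined in the new module (padding=1, causal=2, paddedcausal=3); a quick sketch, assuming AttnMaskType is importable from the new colo_attention module:

from colossalai.nn.layer.colo_attention import AttnMaskType

for mask_type in AttnMaskType:
    padded = mask_type.value % 2 == 1  # padding(1) and paddedcausal(3)
    causal = mask_type.value > 1       # causal(2) and paddedcausal(3)
    print(f"{mask_type.name}: padded={padded}, causal={causal}")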
@@ -9,7 +9,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn import init
 from torch.nn.parameter import Parameter

-from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
+from colossalai.kernel.kernel_loader import LayerNormLoader

 try:
     from colossalai._C import layer_norm
@@ -29,7 +29,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
         global layer_norm
         if layer_norm is None:
-            layer_norm = LayerNormBuilder().load()
+            layer_norm = LayerNormLoader().load()
         output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
         ctx.layernorm_op = layer_norm
         ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
...
# This code is adapted from NVIDIA Megatron, with minor changes.
import enum
import torch
import torch.nn as nn
from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader
# kernel modules are loaded lazily on first use
scaled_upper_triang_masked_softmax = None
scaled_masked_softmax = None
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence:
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
global scaled_upper_triang_masked_softmax
if scaled_upper_triang_masked_softmax is None:
scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence:
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
scale_t = torch.tensor([scale])
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
Fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: Flag to indicate if input in fp16 data format.
input_in_bf16: Flag to indicate if input in bf16 data format.
attn_mask_type: Attention mask type (pad or causal)
scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion
mask_func: Mask function to be applied.
softmax_in_fp32: If True, softmax is performed in fp32 precision.
scale: Scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion  # the user wants the fused kernel
and self.input_in_float16  # input must be fp16 or bf16
and mask is not None  # mask tensor must not be None
and 16 < sk <= 2048  # sk must be between 16 and 2048
and sq % 4 == 0  # sq must be divisible by 4
and attn_batches % 4 == 0  # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
def get_batch_per_block(self, sq, sk, b, np):
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
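A minimal usage sketch for FusedScaleMaskSoftmax, assuming a CUDA build where the fused softmax kernels can be loaded; attention_mask_func is illustrative and follows the usual Megatron convention of writing a large negative value into masked positions (it is only used on the unfused fallback path).

import torch


def attention_mask_func(attention_scores, attention_mask):
    return attention_scores.masked_fill(attention_mask, -10000.0)


fused_softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.padding,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=True,
    scale=None,
)

# [batch, num_heads, q_len, k_len]; an all-False mask keeps every position
scores = torch.randn(4, 8, 128, 128, dtype=torch.float16, device="cuda")
mask = torch.zeros(4, 1, 128, 128, dtype=torch.bool, device="cuda")
probs = fused_softmax(scores, mask)  # dispatches to the fused kernel when is_kernel_available(...) is True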
@@ -3,7 +3,7 @@ from typing import Optional

 import torch

-from colossalai.kernel import CPUAdamLoader
+from colossalai.kernel.kernel_loader import CPUAdamLoader

 from .nvme_optimizer import NVMeOptimizer
...
@@ -70,9 +70,9 @@ class FusedAdam(torch.optim.Optimizer):
         self.adamw_mode = 1 if adamw_mode else 0
         self.set_grad_none = set_grad_none
         if multi_tensor_applier.available:
-            from colossalai.kernel.op_builder import FusedOptimBuilder
+            from colossalai.kernel.kernel_loader import FusedOptimizerLoader

-            fused_optim = FusedOptimBuilder().load()
+            fused_optim = FusedOptimizerLoader().load()

             # Skip buffer
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
...
@@ -77,9 +77,9 @@ class FusedLAMB(torch.optim.Optimizer):
         )
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
-            from colossalai.kernel.op_builder import FusedOptimBuilder
+            from colossalai.kernel.kernel_loader import FusedOptimizerLoader

-            fused_optim = FusedOptimBuilder().load()
+            fused_optim = FusedOptimizerLoader().load()
             self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm

             # Skip buffer
...
@@ -72,9 +72,9 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum

         if multi_tensor_applier.available:
-            from colossalai.kernel.op_builder import FusedOptimBuilder
+            from colossalai.kernel.kernel_loader import FusedOptimizerLoader

-            fused_optim = FusedOptimBuilder().load()
+            fused_optim = FusedOptimizerLoader().load()

             # Skip buffer
             self._dummy_overflow_buf = torch.tensor(
...
@@ -2,7 +2,7 @@ from typing import Any, Optional

 import torch

-from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.kernel.kernel_loader import FusedOptimizerLoader
 from colossalai.utils import multi_tensor_applier

 from .cpu_adam import CPUAdam
@@ -85,7 +85,7 @@ class HybridAdam(CPUAdam):
             nvme_offload_dir,
         )
         if torch.cuda.is_available():
-            fused_optim = FusedOptimBuilder().load()
+            fused_optim = FusedOptimizerLoader().load()
             self.gpu_adam_op = fused_optim.multi_tensor_adam
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
...
@@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.utils import get_current_device

 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
...
@@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.utils import get_current_device

 from ._utils import (
     detach,
...
@@ -62,7 +62,7 @@ def forward_fn():
 def get_blip2_flash_attention_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2Attention

-    from colossalai.kernel import ColoAttention
+    from colossalai.nn.layer.colo_attention import ColoAttention

     def forward(
         self: Blip2Attention,
...