Unverified commit dd2c28a3 authored by Xuanlei Zhao, committed by GitHub

[npu] use extension for op builder (#5172)

* update extension

* update cpu adam

* update is

* add doc for cpu adam

* update kernel

* update commit

* update flash

* update memory efficient

* update flash attn

* update flash attention loader

* update api

* fix

* update doc

* update example time limit

* reverse change

* fix doc

* remove useless kernel

* fix

* not use warning

* update

* update
parent d6df19ba
from .mha import NPUColoAttention
__all__ = ["NPUColoAttention"]
import math
from typing import Optional
import torch
from .sdpa_attn import npu_sdpa_attention
from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION
class NPUColoAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: Optional[float] = None):
        super().__init__()

        try:
            import torch_npu  # noqa
        except ImportError:
            raise ImportError("torch_npu is not installed.")

        assert (
            embed_dim % num_heads == 0
        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."

        if scale is not None:
            self.scale = scale
        else:
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
        self.dropout = dropout
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: int = None,
        bias: Optional[torch.Tensor] = None,
    ):
"""
Implement the scaled dot product attention with softmax.
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
        assert (
            len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
        ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
        assert (
            query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
        ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
        assert bias is None, "bias is not supported in npu colo attention"

        causal = attn_mask_type is not None and attn_mask_type.value > 1

        if HAS_NPU_TRIANGLE_ATTENTION:
            from .triangle_attn import npu_triangle_attention

            attn_fn = npu_triangle_attention
        else:
            attn_fn = npu_sdpa_attention

        out = attn_fn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            origin_attn_mask=origin_attn_mask,
            dropout_p=self.dropout,
            scale=self.scale,
            is_causal=causal,
        )
        return out
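For orientation (not part of the diff), here is a minimal usage sketch of NPUColoAttention, assuming a working torch_npu installation and an available "npu" device; the import path mirrors the one used by the get_attention_kernel helper further down, and the sizes are made up for illustration.

import torch

from colossalai.kernel.npu import NPUColoAttention

# Hypothetical sizes, chosen only for illustration.
batch, seq_len, num_heads, head_dim = 2, 128, 8, 64

attn = NPUColoAttention(embed_dim=num_heads * head_dim, num_heads=num_heads, dropout=0.1)

q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="npu")
k = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="npu")
v = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="npu")

# With no attn_mask_type given, the non-causal path is taken; see the docstring above for the output layout.
out = attn(q, k, v)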
import math
import platform
from typing import Optional
import torch
-from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
+from colossalai.kernel import CPUAdamLoader
from .nvme_optimizer import NVMeOptimizer
......@@ -78,7 +77,7 @@ class CPUAdam(NVMeOptimizer):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
-cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
+cpu_adam = CPUAdamLoader().load()
# if you find yourself stuck here, make sure colossalai was installed with the CUDA_EXT=1 flag
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
......
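As a quick illustration of the change above (not part of the diff): the platform check that used to live at this call site now sits behind the loader. A minimal sketch, using only the calls visible in this hunk; the hyperparameter values are illustrative.

from colossalai.kernel import CPUAdamLoader

# The loader replaces the manual choice between CPUAdamBuilder and ArmCPUAdamBuilder.
cpu_adam = CPUAdamLoader().load()

# Same constructor call as in CPUAdam.__init__ above.
lr, beta1, beta2, eps, weight_decay, adamw_mode = 1e-3, 0.9, 0.999, 1e-8, 0.0, True
cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw_mode)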
......@@ -6,7 +6,8 @@ import torch.distributed as dist
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size
-from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
+from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state
class SeqParallelUtils:
......@@ -280,21 +281,3 @@ def create_randomizer_with_offset(
Randomizer.increment_index()
return Randomizer(seed=base_seed)
-def get_attention_kernel():
-    """
-    Get the attention kernel based on the device type.
-    """
-    from colossalai.kernel.cuda_native import AttnMaskType
-    if torch.cuda.is_available():
-        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
-    else:
-        try:
-            torch.npu.is_available()
-            from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
-        except:
-            raise Exception("No available device for attention kernel!")
-    return AttnMaskType, AttentionKernel
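With get_attention_kernel deleted, the call sites touched in this commit switch to a single import; a minimal sketch of the new pattern (not part of the diff), assuming the device dispatch now happens inside colossalai.kernel, with hypothetical sizes.

# Old pattern (deleted above): pick the kernel per device at each call site.
# AttnMaskType, ColoAttention = get_attention_kernel()

# New pattern used throughout this diff:
from colossalai.kernel import AttnMaskType, ColoAttention

attention = ColoAttention(embed_dim=4096, num_heads=32)  # hypothetical sizes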
......@@ -62,7 +62,7 @@ def forward_fn():
def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
-from colossalai.kernel.cuda_native import ColoAttention
+from colossalai.kernel import ColoAttention
def forward(
self: Blip2Attention,
......
......@@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
def get_flash_core_attention_forward():
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.kernel import AttnMaskType, ColoAttention
from .chatglm2_6b.modeling_chatglm import CoreAttention
......
......@@ -719,7 +719,7 @@ class GPT2PipelineForwards:
def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.kernel import AttnMaskType, ColoAttention
def split_heads(tensor, num_heads, attn_head_size):
"""
......
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
......@@ -12,14 +12,15 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer.utils import get_attention_kernel
try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
LATEST_VERSION = True
except ImportError:
LATEST_VERSION = False
class LlamaPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
......@@ -405,7 +406,7 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward():
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-AttnMaskType, ColoAttention = get_attention_kernel()
+from colossalai.kernel import AttnMaskType, ColoAttention
llama_version = 2
try:
......@@ -469,7 +470,12 @@ def get_llama_flash_attention_forward():
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(
-query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
+query_states,
+key_states,
+value_states,
+attn_mask=flash_attention_mask,
+attn_mask_type=attn_mask_type,
+origin_attn_mask=attention_mask,
)
attn_output = self.o_proj(attn_output)
......
......@@ -514,7 +514,7 @@ class OPTPipelineForwards:
def get_opt_flash_attention_forward():
from transformers.models.opt.modeling_opt import OPTAttention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.kernel import AttnMaskType, ColoAttention
def forward(
self: OPTAttention,
......
......@@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention
-from colossalai.kernel.cuda_native import ColoAttention
+from colossalai.kernel import ColoAttention
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
......
......@@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.kernel import AttnMaskType, ColoAttention
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
......
......@@ -35,7 +35,7 @@ from transformers.utils import (
replace_return_docstrings,
)
-from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe.layers import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
......
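The availability flag now comes from the extensions package; a short sketch (not from the diff) of the guard pattern this import supports, with the fallback branch purely illustrative.

from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN

if HAS_FLASH_ATTN:
    # a flash-attention-backed forward can be substituted
    use_flash_attention = True
else:
    # fall back to the default attention implementation
    use_flash_attention = False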
......@@ -90,9 +90,9 @@ class FusedAdamKernel(AdamKernel):
class CPUAdamKernel(AdamKernel):
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
-from colossalai.kernel.op_builder import CPUAdamBuilder
+from colossalai.kernel import CPUAdamLoader
-cpu_optim = CPUAdamBuilder().load()
+cpu_optim = CPUAdamLoader().load()
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)
......
......@@ -4,13 +4,11 @@ import pytest
import torch
from einops import rearrange
-from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
-from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
+from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
-from colossalai.kernel.cuda_native import ColoAttention
-from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.kernel import AttnMaskType, ColoAttention
DTYPE = [torch.float16, torch.bfloat16, torch.float32]
......
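A hedged sketch of how a test can guard on these flags through the unified imports; the test body and sizes are hypothetical and not taken from this file.

import pytest
import torch

from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN


@pytest.mark.skipif(not (HAS_FLASH_ATTN or HAS_MEM_EFF_ATTN), reason="no flash / memory-efficient attention kernel")
def test_attention_smoke():
    from colossalai.kernel import AttnMaskType, ColoAttention

    attn = ColoAttention(embed_dim=512, num_heads=8)  # hypothetical sizes
    q = k = v = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")
    out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
    assert out.shape[:2] == (1, 16)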