Unverified Commit dd2c28a3 authored by Xuanlei Zhao, committed by GitHub

[npu] use extension for op builder (#5172)

* update extension

* update cpu adam

* update is

* add doc for cpu adam

* update kernel

* update commit

* update flash

* update memory efficient

* update flash attn

* update flash attention loader

* update api

* fix

* update doc

* update example time limit

* reverse change

* fix doc

* remove useless kernel

* fix

* not use warning

* update

* update
parent d6df19ba
from .mha import NPUColoAttention
__all__ = ["NPUColoAttention"]
import math
from typing import Optional

import torch

from .sdpa_attn import npu_sdpa_attention
from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION


class NPUColoAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None):
        super().__init__()

        try:
            import torch_npu  # noqa
        except ImportError:
            raise Exception("torch_npu is not installed.")

        assert (
            embed_dim % num_heads == 0
        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
        if scale is not None:
            self.scale = scale
        else:
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
        self.dropout = dropout

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: int = None,
        bias: Optional[torch.Tensor] = None,
    ):
        """
        Implement the scaled dot product attention with softmax.

        Arguments:
            q: (batch, q_seqlen, nheads, headdim)
            k: (batch, kv_seqlen, nheads, headdim)
            v: (batch, kv_seqlen, nheads, headdim)
            batch_size: int.
            seq_len: int.
            dropout_p: float. Dropout probability.
            scale: float. The scaling of QK^T before applying softmax.
                Default to 1.

        Return:
            attn_out: (batch, q_seqlen, nheads, headdim).
        """
        assert (
            len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
        ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
        assert (
            query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
        ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
        assert bias is None, "bias is not supported in npu colo attention"

        causal = attn_mask_type is not None and attn_mask_type.value > 1

        if HAS_NPU_TRIANGLE_ATTENTION:
            from .triangle_attn import npu_triangle_attention

            attn_fn = npu_triangle_attention
        else:
            attn_fn = npu_sdpa_attention

        out = attn_fn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            origin_attn_mask=origin_attn_mask,
            dropout_p=self.dropout,
            scale=self.scale,
            is_causal=causal,
        )
        return out
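For orientation, a minimal usage sketch of the module above. The shapes follow the docstring; the dimensions, dtype, and import path are illustrative assumptions (the path mirrors the old helper's "from colossalai.kernel.npu import NPUColoAttention"), and running it requires an Ascend device with torch_npu installed.

import torch

# Hypothetical example -- adjust the import to wherever the package exposes NPUColoAttention.
from colossalai.kernel.npu import NPUColoAttention

attn = NPUColoAttention(embed_dim=512, num_heads=8, dropout=0.1)
# (batch, q_seqlen, nheads, headdim) with headdim = embed_dim // num_heads
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="npu")
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="npu")
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="npu")
out = attn(q, k, v)  # per the docstring: (batch, q_seqlen, nheads, headdim)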
import math
-import platform
from typing import Optional

import torch

-from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
+from colossalai.kernel import CPUAdamLoader

from .nvme_optimizer import NVMeOptimizer

@@ -78,7 +77,7 @@ class CPUAdam(NVMeOptimizer):
        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
        self.adamw_mode = adamw_mode
-        cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
+        cpu_adam = CPUAdamLoader().load()
        # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
        self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
...
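The removed inline platform.machine() check is exactly what the new loader is meant to hide from the call site. The sketch below is illustrative only, not the actual CPUAdamLoader source; the class name _InlineCPUAdamLoader is made up, while the two builders are the ones the old line used.

import platform

# Hypothetical stand-in that mirrors the dispatch the old call site performed inline.
class _InlineCPUAdamLoader:
    def load(self):
        from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder

        # aarch64 hosts get the ARM extension; everything else gets the generic build
        builder = ArmCPUAdamBuilder() if platform.machine() == "aarch64" else CPUAdamBuilder()
        return builder.load()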
@@ -6,7 +6,8 @@ import torch.distributed as dist
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size

-from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
+from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state

class SeqParallelUtils:
@@ -280,21 +281,3 @@ def create_randomizer_with_offset(
    Randomizer.increment_index()

    return Randomizer(seed=base_seed)
-
-
-def get_attention_kernel():
-    """
-    Get the attention kernel based on the device type.
-    """
-    from colossalai.kernel.cuda_native import AttnMaskType
-
-    if torch.cuda.is_available():
-        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
-    else:
-        try:
-            torch.npu.is_available()
-            from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
-        except:
-            raise Exception("No available device for attention kernel!")
-
-    return AttnMaskType, AttentionKernel
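The helper above is deleted because the kernel choice no longer happens at the call site; the llama hunk further down shows the substitution, roughly:

# Before this commit: device dispatch through the helper.
#     AttnMaskType, AttentionKernel = get_attention_kernel()
# After this commit: import directly; colossalai.kernel is expected to pick the backend.
from colossalai.kernel import AttnMaskType, ColoAttention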
@@ -62,7 +62,7 @@ def forward_fn():
def get_blip2_flash_attention_forward():
    from transformers.models.blip_2.modeling_blip_2 import Blip2Attention

-    from colossalai.kernel.cuda_native import ColoAttention
+    from colossalai.kernel import ColoAttention

    def forward(
        self: Blip2Attention,
...
@@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
def get_flash_core_attention_forward():
-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.kernel import AttnMaskType, ColoAttention

    from .chatglm2_6b.modeling_chatglm import CoreAttention
...
@@ -719,7 +719,7 @@ class GPT2PipelineForwards:
def get_gpt2_flash_attention_forward():
    from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.kernel import AttnMaskType, ColoAttention

    def split_heads(tensor, num_heads, attn_head_size):
        """
...
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -12,14 +12,15 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer.utils import get_attention_kernel

try:
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

    LATEST_VERSION = True
except ImportError:
    LATEST_VERSION = False

class LlamaPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of Llama models
@@ -405,7 +406,7 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward():
    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    AttnMaskType, ColoAttention = get_attention_kernel()
+    from colossalai.kernel import AttnMaskType, ColoAttention

    llama_version = 2
    try:
@@ -469,7 +470,12 @@ def get_llama_flash_attention_forward():
        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
        attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=flash_attention_mask,
+            attn_mask_type=attn_mask_type,
+            origin_attn_mask=attention_mask,
        )

        attn_output = self.o_proj(attn_output)
...
@@ -514,7 +514,7 @@ class OPTPipelineForwards:
def get_opt_flash_attention_forward():
    from transformers.models.opt.modeling_opt import OPTAttention

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.kernel import AttnMaskType, ColoAttention

    def forward(
        self: OPTAttention,
...
@@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
def get_vit_flash_self_attention_forward():
    from transformers.models.vit.modeling_vit import ViTSelfAttention

-    from colossalai.kernel.cuda_native import ColoAttention
+    from colossalai.kernel import ColoAttention

    def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
...
@@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
def get_whisper_flash_attention_forward():
    from transformers.models.whisper.modeling_whisper import WhisperAttention

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.kernel import AttnMaskType, ColoAttention

    def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
        return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
...
@@ -35,7 +35,7 @@ from transformers.utils import (
    replace_return_docstrings,
)

-from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe.layers import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
...
@@ -90,9 +90,9 @@ class FusedAdamKernel(AdamKernel):
class CPUAdamKernel(AdamKernel):
    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
-        from colossalai.kernel.op_builder import CPUAdamBuilder
+        from colossalai.kernel import CPUAdamLoader

-        cpu_optim = CPUAdamBuilder().load()
+        cpu_optim = CPUAdamLoader().load()
        self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)
...
@@ -4,13 +4,11 @@ import pytest
import torch
from einops import rearrange

-from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
-from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
+from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize

if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
-    from colossalai.kernel.cuda_native import ColoAttention
-    from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+    from colossalai.kernel import AttnMaskType, ColoAttention

DTYPE = [torch.float16, torch.bfloat16, torch.float32]
...
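As a reference point, a minimal smoke-test sketch of how the guarded import above is typically consumed; the test body, tensor sizes, and shape check are assumptions, not part of this commit.

import pytest
import torch

from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN

@pytest.mark.skipif(not (HAS_FLASH_ATTN or HAS_MEM_EFF_ATTN), reason="no flash or memory-efficient attention kernel")
def test_colo_attention_smoke():
    from colossalai.kernel import AttnMaskType, ColoAttention

    attn = ColoAttention(embed_dim=64, num_heads=4, dropout=0.0)
    # (batch, seqlen, nheads, headdim) layout, as in the NPU module above
    q, k, v = (torch.randn(2, 16, 4, 16, dtype=torch.float16, device="cuda") for _ in range(3))
    out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
    # batch and sequence dims are preserved; the head layout of the output depends on the kernel
    assert out.shape[:2] == (2, 16)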