Unverified commit 7cfed5f0 authored by Frank Lee, committed by GitHub

[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
parent d7f8db8e
@@ -6,8 +6,7 @@ import enum
 import torch
 import torch.nn as nn
 
-from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder
+from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader
 
 try:
     from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
@@ -35,7 +34,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     def forward(ctx, inputs, scale):
         global scaled_upper_triang_masked_softmax
         if scaled_upper_triang_masked_softmax:
-            scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load()
+            scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()
         scale_t = torch.tensor([scale])
         softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
@@ -67,7 +66,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         # build and load kernel if not pre-built
         global scaled_masked_softmax
         if scaled_masked_softmax is None:
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
+            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
         softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
...
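For context, both hunks above rely on the same lazy-loading pattern: the module-level kernel handle starts as None and the loader compiles or imports the kernel only on the first forward call. A minimal sketch of that pattern, assuming ScaledMaskedSoftmaxLoader exposes a load() method as the new import suggests:

# Sketch of the lazy-load pattern used in the hunks above; assumes
# ScaledMaskedSoftmaxLoader().load() returns a module exposing forward/backward.
scaled_masked_softmax = None

def _get_scaled_masked_softmax():
    global scaled_masked_softmax
    if scaled_masked_softmax is None:
        from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader

        scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
    return scaled_masked_softmax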
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List

# time, Path and _Extension are referenced by build_jit below
from .base_extension import _Extension
from .cpp_extension import _CppExtension
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list

__all__ = ["_CudaExtension"]
# Some constants for installation checks
MIN_PYTORCH_VERSION_MAJOR = 1
MIN_PYTORCH_VERSION_MINOR = 10
class _CudaExtension(_CppExtension):
@abstractmethod
def nvcc_flags(self) -> List[str]:
"""
This function should return a list of nvcc compilation flags for extensions.
"""
def is_hardware_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_hardware_compatible(self) -> None:
from torch.utils.cpp_extension import CUDA_HOME
if not CUDA_HOME:
raise AssertionError(
"[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions"
)
check_system_pytorch_cuda_match(CUDA_HOME)
check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)
def get_cuda_home_include(self):
"""
return include path inside the cuda home.
"""
from torch.utils.cpp_extension import CUDA_HOME
if CUDA_HOME is None:
raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
cuda_include = os.path.join(CUDA_HOME, "include")
return cuda_include
def build_jit(self) -> None:
from torch.utils.cpp_extension import CUDA_HOME, load
set_cuda_arch_list(CUDA_HOME)
# get build dir
build_directory = _Extension.get_jit_extension_folder_path()
build_directory = Path(build_directory)
build_directory.mkdir(parents=True, exist_ok=True)
# check if the kernel has been built
compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
if kernel_file_path.exists():
compiled_before = True
# load the kernel
if compiled_before:
print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
else:
print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
build_start = time.time()
op_kernel = load(
name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_cuda_cflags=self.nvcc_flags(),
extra_ldflags=[],
build_directory=str(build_directory),
)
build_duration = time.time() - build_start
if compiled_before:
print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
else:
print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
return op_kernel
def build_aot(self) -> "CUDAExtension":
from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension
set_cuda_arch_list(CUDA_HOME)
return CUDAExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args={
"cxx": self.strip_empty_entries(self.cxx_flags()),
"nvcc": self.strip_empty_entries(self.nvcc_flags()),
},
)
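As a usage illustration (not part of this commit), a concrete kernel would subclass _CudaExtension and fill in the four hooks; the extension name and source file names below are placeholders:

# Hypothetical subclass showing the _CudaExtension contract; "my_kernel_cuda" and the
# listed source files are placeholders, not files that exist in the repository.
class MyKernelCudaExtension(_CudaExtension):
    def __init__(self):
        super().__init__(name="my_kernel_cuda")

    def sources_files(self):
        # paths are resolved relative to the csrc directory by csrc_abs_path
        return [self.csrc_abs_path(f) for f in ["cuda/my_kernel.cpp", "cuda/my_kernel_cuda.cu"]]

    def include_dirs(self):
        return [self.get_cuda_home_include()]

    def cxx_flags(self):
        return ["-O3"]

    def nvcc_flags(self):
        return ["-O3", "--use_fast_math"]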
from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension
from .flash_attention_npu import FlashAttentionNpuExtension
from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension
try:
import flash_attention # noqa
HAS_FLASH_ATTN = True
except:
HAS_FLASH_ATTN = False
try:
import xformers # noqa
HAS_MEM_EFF_ATTN = True
except:
HAS_MEM_EFF_ATTN = False
__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"]
from ..base_extension import _Extension
class FlashAttentionDaoCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10)
def is_hardware_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_hardware_compatible(self) -> None:
pass
def build_aot(self) -> None:
raise NotImplementedError(
"We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'."
)
def build_jit(self) -> None:
raise NotImplementedError(
"We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
)
def load(self):
try:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
except ImportError:
raise ModuleNotFoundError(
(
"We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
)
)
from typing import Optional
import torch
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: "SeqLenInfo",
seq_len_info_kv: "SeqLenInfo",
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
"""
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
# check if the input is in allowed dtypes
if padded:
if seq_len_info_kv is None:
seq_len_info_kv = seq_len_info_q
attn_out = flash_attn_varlen_func(
q,
k,
v,
seq_len_info_q.cu_seqlens,
seq_len_info_kv.cu_seqlens,
seq_len_info_q.max_seqlen,
seq_len_info_kv.max_seqlen,
dropout_p,
scale,
causal,
)
else:
attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
return attn_out
return flash_attention
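For illustration, the callable returned by load() can be exercised on the unpadded path as follows; this assumes a CUDA device and the flash-attn package, with shapes following the docstring above:

# Hedged usage sketch (requires CUDA and flash-attn; fp16/bf16 inputs are expected).
import torch

attn_fn = FlashAttentionDaoCudaExtension().load()
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
# seq_len_info_q / seq_len_info_kv are only consulted when padded=True, so None is fine here.
out = attn_fn(q, k, v, None, None, dropout_p=0.0, causal=True)  # (2, 128, 8, 64)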
from ..base_extension import _Extension
class FlashAttentionNpuExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)
def is_hardware_available(self) -> bool:
try:
import torch_npu # noqa
return True
except:
return False
def assert_hardware_compatible(self) -> None:
pass
def build_aot(self) -> None:
raise NotImplementedError(
"Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu."
)
def build_jit(self) -> None:
raise NotImplementedError(
"Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu."
)
def load(self):
import torch
from einops import rearrange
def npu_sdpa_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q=None,
seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
dropout_p: float = 0.0,
scale: float = 1.0,
causal=None,
padded=None,
):
"""
The scaled dot product attention.
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=origin_attn_mask,
dropout_p=dropout_p,
is_causal=origin_attn_mask is None,
scale=scale,
)
output = rearrange(output, "b h s d -> b s (h d)")
return output
return npu_sdpa_attention
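As a usage note, the returned function takes the same (batch, seqlen, nheads, headdim) layout but merges the heads on output. A hedged sketch, assuming PyTorch 2.1+ (for the scale argument of scaled_dot_product_attention) and einops installed; torch_npu is only needed for the hardware check, not for the shape check below:

# Hedged usage sketch; only exercises the shape contract of npu_sdpa_attention.
import torch

sdpa_fn = FlashAttentionNpuExtension().load()
q = k = v = torch.randn(2, 16, 4, 8)
out = sdpa_fn(q, k, v, scale=8**-0.5)
print(out.shape)  # torch.Size([2, 16, 32]); heads are merged into the last dimension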
from ..base_extension import _Extension
class FlashAttentionXformersCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)
def is_hardware_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_hardware_compatible(self) -> None:
pass
def build_aot(self) -> None:
raise NotImplementedError(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
def build_jit(self) -> None:
raise NotImplementedError(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
def load(self):
try:
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
except ImportError:
raise ModuleNotFoundError(
(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
)
from typing import Optional
import torch
allow_alibi = True
for op in MemoryEfficientAttentionCutlassOp:
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
def mem_eff_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: "SeqLenInfo",
seq_len_info_kv: "SeqLenInfo",
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
attn_bias = None
if padded: # bert style
if not causal:
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
elif causal: # gpt style
attn_bias = LowerTriangularMask()
if bias is not None: # alibi / relative position embedding
assert allow_alibi, "flash attention with bias is not supported in this system."
assert causal, "attention with bias is only supported for causal attention so far."
attn_bias = attn_bias.add_bias(bias)
if padded:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
# shape: (b*s, n, d)
if padded:
out = out.squeeze(0)
return out
return mem_eff_attention
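For illustration, the GPT-style causal path of the returned function can be exercised like this; it assumes xformers and a CUDA device:

# Hedged usage sketch (requires xformers and CUDA; causal=True selects LowerTriangularMask).
import torch

mea_fn = FlashAttentionXformersCudaExtension().load()
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
out = mea_fn(q, k, v, None, None, causal=True)  # (2, 128, 8, 64)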
from .layernorm_cuda import LayerNormCudaExtension
__all__ = ["LayerNormCudaExtension"]
\ No newline at end of file
-from .builder import Builder
-from .utils import append_nvcc_threads, get_cuda_cc_flag
+from ..cuda_extension import _CudaExtension
+from ..utils import append_nvcc_threads, get_cuda_cc_flag
 
-class LayerNormBuilder(Builder):
-    NAME = "layernorm"
-    PREBUILT_IMPORT_PATH = "colossalai._C.layernorm"
+class LayerNormCudaExtension(_CudaExtension):
     def __init__(self):
-        super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH)
+        super().__init__(name="layernorm_cuda")
 
     def sources_files(self):
-        ret = [self.csrc_abs_path(fname) for fname in ["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"]]
+        ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]]
         return ret
 
     def include_dirs(self):
-        ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()]
+        ret = [self.get_cuda_home_include()]
         return ret
 
     def cxx_flags(self):
...
from .moe_cuda import MoeCudaExtension
__all__ = ['MoeCudaExtension']
\ No newline at end of file
-from .builder import Builder
-from .utils import append_nvcc_threads, get_cuda_cc_flag
+from ..cuda_extension import _CudaExtension
+from ..utils import append_nvcc_threads, get_cuda_cc_flag
 
-class MOEBuilder(Builder):
-    NAME = "moe"
-    PREBUILT_IMPORT_PATH = "colossalai._C.moe"
+class MoeCudaExtension(_CudaExtension):
     def __init__(self):
-        super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH)
+        super().__init__(name="moe_cuda")
 
     def include_dirs(self):
-        ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()]
+        ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
         return ret
 
     def sources_files(self):
-        ret = [self.csrc_abs_path(fname) for fname in ["moe_cuda.cpp", "moe_cuda_kernel.cu"]]
+        ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]]
         return ret
 
     def cxx_flags(self):
...
from .fused_optimizer_cuda import FusedOptimizerCudaExtension
__all__ = ['FusedOptimizerCudaExtension']
\ No newline at end of file
-from .builder import Builder
-from .utils import get_cuda_cc_flag
+from ..cuda_extension import _CudaExtension
+from ..utils import get_cuda_cc_flag
 
-class FusedOptimBuilder(Builder):
-    NAME = "fused_optim"
-    PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim"
+class FusedOptimizerCudaExtension(_CudaExtension):
     def __init__(self):
-        super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH)
+        super().__init__(name="fused_optim_cuda")
 
     def sources_files(self):
         ret = [
             self.csrc_abs_path(fname)
             for fname in [
-                "colossal_C_frontend.cpp",
-                "multi_tensor_sgd_kernel.cu",
-                "multi_tensor_scale_kernel.cu",
-                "multi_tensor_adam.cu",
-                "multi_tensor_l2norm_kernel.cu",
-                "multi_tensor_lamb.cu",
+                "cuda/colossal_C_frontend.cpp",
+                "cuda/multi_tensor_sgd_kernel.cu",
+                "cuda/multi_tensor_scale_kernel.cu",
+                "cuda/multi_tensor_adam.cu",
+                "cuda/multi_tensor_l2norm_kernel.cu",
+                "cuda/multi_tensor_lamb.cu",
             ]
         ]
         return ret
 
     def include_dirs(self):
-        ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()]
+        ret = [self.get_cuda_home_include()]
         return ret
 
     def cxx_flags(self):
...
from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension
from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension
__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension']
\ No newline at end of file
-from .builder import Builder
-from .utils import append_nvcc_threads
+from ..cuda_extension import _CudaExtension
+from ..utils import append_nvcc_threads
 
-class ScaledMaskedSoftmaxBuilder(Builder):
-    NAME = "scaled_masked_softmax"
-    PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax"
+class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
     def __init__(self):
-        super().__init__(
-            name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH
-        )
+        super().__init__(name="scaled_masked_softmax_cuda")
 
-    # necessary 4 functions
     def sources_files(self):
-        ret = [self.csrc_abs_path(fname) for fname in ["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"]]
+        ret = [
+            self.csrc_abs_path(fname)
+            for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"]
+        ]
         return ret
 
     def include_dirs(self):
-        return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()]
+        return [self.get_cuda_home_include()]
 
     def cxx_flags(self):
         return ["-O3"] + self.version_dependent_macros
...
-from .builder import Builder
-from .utils import append_nvcc_threads, get_cuda_cc_flag
+from ..cuda_extension import _CudaExtension
+from ..utils import append_nvcc_threads, get_cuda_cc_flag
 
-class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder):
-    NAME = "scaled_upper_triangle_masked_softmax"
-    PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax"
+class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension):
     def __init__(self):
-        super().__init__(
-            name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME,
-            prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH,
-        )
+        super().__init__(name="scaled_upper_triangle_masked_softmax_cuda")
 
     def include_dirs(self):
-        return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()]
+        return [self.get_cuda_home_include()]
 
     def sources_files(self):
         ret = [
             self.csrc_abs_path(fname)
-            for fname in ["scaled_upper_triang_masked_softmax.cpp", "scaled_upper_triang_masked_softmax_cuda.cu"]
+            for fname in [
+                "cuda/scaled_upper_triang_masked_softmax.cpp",
+                "cuda/scaled_upper_triang_masked_softmax_cuda.cu",
+            ]
         ]
         return ret
...
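For completeness, the CUDAExtension objects produced by build_aot() are what a setup script would hand to setuptools; a hedged sketch of that wiring (the package name and the ext_modules list are illustrative, not the project's actual setup.py):

# Hypothetical AOT build wiring; illustrative only.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension

ext_modules = [
    FusedOptimizerCudaExtension().build_aot(),
    LayerNormCudaExtension().build_aot(),
    ScaledMaskedSoftmaxCudaExtension().build_aot(),
]

setup(
    name="colossalai-kernels",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)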