"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2c0da7803a75f0fc6e6d484e23ca283faa32d785"
Unverified commit 5e95dcab authored by Arthur, committed by GitHub

[`cuda kernels`] only compile them when initializing (#29133)

* only compile when needed

* fix mra as well

* fix yoso as well

* update

* remove comment

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

* oops

* Update src/transformers/models/deta/modeling_deta.py

* nit
parent a7755d24
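
All four models get the same treatment: the call to `torch.utils.cpp_extension.load` moves out of module import and into the first `__init__` of the layer that needs the kernel, guarded by `is_torch_cuda_available()`, `is_ninja_available()`, and a module-level handle that records whether compilation already happened. A minimal sketch of that pattern follows; `my_kernel`, `MyAttention`, and the kernel source file names are illustrative placeholders, not the actual transformers files.

    # Lazy CUDA-kernel compilation: nothing is built at import time.
    from pathlib import Path

    import torch.nn as nn

    from transformers.utils import is_ninja_available, is_torch_cuda_available, logging

    logger = logging.get_logger(__name__)

    my_kernel = None  # module-level handle; stays None until a layer instance needs it


    def load_cuda_kernels():
        # Compile the extension with ninja only when this is called, never on import.
        from torch.utils.cpp_extension import load

        global my_kernel
        root = Path(__file__).resolve().parent / "kernels"  # placeholder location
        my_kernel = load("my_kernel", [root / "my_kernel.cu", root / "torch_extension.cpp"], verbose=True)


    class MyAttention(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Compile once, when the first instance is built, and only if CUDA and ninja are available.
            if is_torch_cuda_available() and is_ninja_available() and my_kernel is None:
                try:
                    load_cuda_kernels()
                except Exception as e:
                    logger.warning(f"Could not load the custom CUDA kernel: {e}")

Importing the modeling file therefore stays cheap on CPU-only machines; the compile cost (and any compile failure) is only paid when a model that actually uses the kernel is instantiated.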
src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -17,8 +17,10 @@
 import copy
 import math
+import os
 import warnings
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union

 import torch
@@ -46,21 +48,42 @@ from ...pytorch_utils import meshgrid
 from ...utils import is_accelerate_available, is_ninja_available, logging
 from ...utils.backbone_utils import load_backbone
 from .configuration_deformable_detr import DeformableDetrConfig
-from .load_custom import load_cuda_kernels


 logger = logging.get_logger(__name__)

-# Move this to not compile only when importing, this needs to happen later, like in __init__.
-if is_torch_cuda_available() and is_ninja_available():
-    logger.info("Loading custom CUDA kernels...")
-    try:
-        MultiScaleDeformableAttention = load_cuda_kernels()
-    except Exception as e:
-        logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
-        MultiScaleDeformableAttention = None
-else:
-    MultiScaleDeformableAttention = None
+MultiScaleDeformableAttention = None
+
+
+def load_cuda_kernels():
+    from torch.utils.cpp_extension import load
+
+    global MultiScaleDeformableAttention
+
+    root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
+    src_files = [
+        root / filename
+        for filename in [
+            "vision.cpp",
+            os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
+            os.path.join("cuda", "ms_deform_attn_cuda.cu"),
+        ]
+    ]
+
+    MultiScaleDeformableAttention = load(
+        "MultiScaleDeformableAttention",
+        src_files,
+        with_cuda=True,
+        extra_include_paths=[str(root)],
+        extra_cflags=["-DWITH_CUDA=1"],
+        extra_cuda_cflags=[
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ],
+    )
+

 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format
@@ -590,6 +613,14 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):

     def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
         super().__init__()
+
+        kernel_loaded = MultiScaleDeformableAttention is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+
         if config.d_model % num_heads != 0:
             raise ValueError(
                 f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
...
src/transformers/models/deta/modeling_deta.py
@@ -50,10 +50,15 @@ from .configuration_deta import DetaConfig

 logger = logging.get_logger(__name__)

+MultiScaleDeformableAttention = None
+
+
+# Copied from models.deformable_detr.load_cuda_kernels
 def load_cuda_kernels():
     from torch.utils.cpp_extension import load

+    global MultiScaleDeformableAttention
+
     root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deta"
     src_files = [
         root / filename
@@ -78,22 +83,6 @@ def load_cuda_kernels():
         ],
     )

-    import MultiScaleDeformableAttention as MSDA
-
-    return MSDA
-
-
-# Move this to not compile only when importing, this needs to happen later, like in __init__.
-if is_torch_cuda_available() and is_ninja_available():
-    logger.info("Loading custom CUDA kernels...")
-    try:
-        MultiScaleDeformableAttention = load_cuda_kernels()
-    except Exception as e:
-        logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
-        MultiScaleDeformableAttention = None
-else:
-    MultiScaleDeformableAttention = None

 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.MultiScaleDeformableAttentionFunction
 class MultiScaleDeformableAttentionFunction(Function):
@@ -596,6 +585,14 @@ class DetaMultiscaleDeformableAttention(nn.Module):

     def __init__(self, config: DetaConfig, num_heads: int, n_points: int):
         super().__init__()
+
+        kernel_loaded = MultiScaleDeformableAttention is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+
         if config.d_model % num_heads != 0:
             raise ValueError(
                 f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
...
src/transformers/models/mra/modeling_mra.py
@@ -58,9 +58,11 @@ MRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all Mra models at https://huggingface.co/models?filter=mra
 ]

+mra_cuda_kernel = None
+

 def load_cuda_kernels():
-    global cuda_kernel
+    global mra_cuda_kernel
     src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mra"

     def append_root(files):
@@ -68,26 +70,7 @@ def load_cuda_kernels():

     src_files = append_root(["cuda_kernel.cu", "cuda_launch.cu", "torch_extension.cpp"])

-    cuda_kernel = load("cuda_kernel", src_files, verbose=True)
-
-    import cuda_kernel
-
-
-cuda_kernel = None
-
-if is_torch_cuda_available() and is_ninja_available():
-    logger.info("Loading custom CUDA kernels...")
-    try:
-        load_cuda_kernels()
-    except Exception as e:
-        logger.warning(
-            "Failed to load CUDA kernels. Mra requires custom CUDA kernels. Please verify that compatible versions of"
-            f" PyTorch and CUDA Toolkit are installed: {e}"
-        )
-else:
-    pass
+    mra_cuda_kernel = load("cuda_kernel", src_files, verbose=True)


 def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
@@ -112,7 +95,7 @@ def sparse_max(sparse_qk_prod, indices, query_num_block, key_num_block):
     indices = indices.int()
     indices = indices.contiguous()

-    max_vals, max_vals_scatter = cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
+    max_vals, max_vals_scatter = mra_cuda_kernel.index_max(index_vals, indices, query_num_block, key_num_block)
     max_vals_scatter = max_vals_scatter.transpose(-1, -2)[:, :, None, :]

     return max_vals, max_vals_scatter
@@ -178,7 +161,7 @@ def mm_to_sparse(dense_query, dense_key, indices, block_size=32):
     indices = indices.int()
     indices = indices.contiguous()

-    return cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())
+    return mra_cuda_kernel.mm_to_sparse(dense_query, dense_key, indices.int())


 def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
@@ -216,7 +199,7 @@ def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
     indices = indices.contiguous()
     dense_key = dense_key.contiguous()

-    dense_qk_prod = cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
+    dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
     dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim)

     return dense_qk_prod
@@ -393,7 +376,7 @@ def mra2_attention(
     """
     Use Mra to approximate self-attention.
     """
-    if cuda_kernel is None:
+    if mra_cuda_kernel is None:
         return torch.zeros_like(query).requires_grad_()

     batch_size, num_head, seq_len, head_dim = query.size()
@@ -561,6 +544,13 @@ class MraSelfAttention(nn.Module):
                 f"heads ({config.num_attention_heads})"
             )

+        kernel_loaded = mra_cuda_kernel is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
...
src/transformers/models/yoso/modeling_yoso.py
@@ -35,7 +35,14 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_ninja_available,
+    is_torch_cuda_available,
+    logging,
+)
 from .configuration_yoso import YosoConfig
@@ -49,28 +56,22 @@ YOSO_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all YOSO models at https://huggingface.co/models?filter=yoso
 ]

+lsh_cumulation = None
+

 def load_cuda_kernels():
     global lsh_cumulation
-    try:
-        from torch.utils.cpp_extension import load
+    from torch.utils.cpp_extension import load

-        def append_root(files):
-            src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
-            return [src_folder / file for file in files]
+    def append_root(files):
+        src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
+        return [src_folder / file for file in files]

-        src_files = append_root(
-            ["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"]
-        )
-
-        load("fast_lsh_cumulation", src_files, verbose=True)
+    src_files = append_root(["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"])

-        import fast_lsh_cumulation as lsh_cumulation
+    load("fast_lsh_cumulation", src_files, verbose=True)

-        return True
-    except Exception:
-        lsh_cumulation = None
-        return False
+    import fast_lsh_cumulation as lsh_cumulation


 def to_contiguous(input_tensors):
@@ -305,6 +306,12 @@ class YosoSelfAttention(nn.Module):
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads})"
             )
+        kernel_loaded = lsh_cumulation is not None
+        if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
+            try:
+                load_cuda_kernels()
+            except Exception as e:
+                logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")

         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
...