"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e9b1635478994f253b6e6f118ecfb88af5a47b94"
Unverified Commit b7b81d93 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Clean up CUDA kernels (#23455)

parent 40ed18ae
......@@ -13,16 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Loading of Deformable DETR's CUDA kernels"""
import os
from pathlib import Path
def load_cuda_kernels():
from torch.utils.cpp_extension import load
root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_kernel")
root = Path(__file__).resolve().parent.parent.parent / "kernels" / "deformable_detr"
src_files = [
os.path.join(root, filename)
root / filename
for filename in [
"vision.cpp",
os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
......@@ -33,10 +33,8 @@ def load_cuda_kernels():
load(
"MultiScaleDeformableAttention",
src_files,
# verbose=True,
with_cuda=True,
extra_include_paths=[root],
# build_directory=os.path.dirname(os.path.realpath(__file__)),
extra_include_paths=[str(root)],
extra_cflags=["-DWITH_CUDA=1"],
extra_cuda_cflags=[
"-DCUDA_HAS_FP16=1",
......
......@@ -16,7 +16,7 @@
import math
import os
from pathlib import Path
from typing import Optional, Tuple, Union
import torch
......@@ -56,8 +56,8 @@ def load_cuda_kernels():
from torch.utils.cpp_extension import load
def append_root(files):
src_folder = os.path.dirname(os.path.realpath(__file__))
return [os.path.join(src_folder, file) for file in files]
src_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "yoso"
return [src_folder / file for file in files]
src_files = append_root(
["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"]
......
......@@ -21,8 +21,8 @@ from pathlib import Path
FILES_TO_FIND = [
"kernels/rwkv/wkv_cuda.cu",
"kernels/rwkv/wkv_op.cpp",
"models/deformable_detr/custom_kernel/ms_deform_attn.h",
"models/deformable_detr/custom_kernel/cuda/ms_deform_im2col_cuda.cuh",
"kernels/deformable_detr/ms_deform_attn.h",
"kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh",
"models/graphormer/algos_graphormer.pyx",
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment