OpenDAS / MMCV

Commit 27432c85
Authored Apr 23, 2023 by xiabo
Commit message: dtk2210.1 torch1.8.0
Parent: b8c09f3b
Showing 3 changed files with 63 additions and 1 deletion:

    mmcv/ops/csrc/carafe_cuda_kernel.cuh   +19  -0
    mmcv/ops/csrc/pytorch/info.cpp          +6  -0
    setup.py                               +38  -1
mmcv/ops/csrc/carafe_cuda_kernel.cuh
@@ -7,7 +7,12 @@
#include "pytorch_cuda_helper.hpp"
#endif

#ifdef HIP_DIFF
#define WARP_SIZE 32
#else
#define WARP_SIZE 64
#endif
#define THREADS_PER_PIXEL 32
#define MAX_SHARED_MEMORY 49152
#define MAX_SHARED_SCALAR_T 6144  // 49152 / 8 = 6144

@@ -25,6 +30,7 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
  return index;
}

/* TODO: move this to a common place */
#ifndef HIP_DIFF
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
  return a < b ? a : b;
@@ -34,19 +40,28 @@ template <typename scalar_t>
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
  return a > b ? a : b;
}
#endif

template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
  for (int offset = 16; offset > 0; offset /= 2)
#ifdef HIP_DIFF
    val += __shfl_down(val, offset);
#else
    val += __shfl_down_sync(FULL_MASK, val, offset);
#endif
  return val;
}

template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) {
  for (int offset = 16; offset > 0; offset /= 2)
#ifdef HIP_DIFF
    __PHALF(val) += __shfl_down(val, offset);
#else
    __PHALF(val) +=
        __shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
#endif
  return val;
}

@@ -302,7 +317,11 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
        output_val += top_diff[top_id] * bottom_data[bottom_id];
      }
    }
#ifdef HIP_DIFF
    __syncthreads();
#else
    __syncwarp();
#endif
    output_val = warpReduceSum(output_val);

    if (lane_id == 0) {
      const int mask_id = ...
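All of the kernel edits follow one portability pattern: every warp-level primitive gets a HIP_DIFF branch, because HIP exposes __shfl_down without a lane mask and has no __syncwarp, so the ROCm build falls back to __shfl_down plus a full __syncthreads. Below is a minimal sketch of that pattern written by us, not code from this commit; warpReduceSumDemo, sum32, and FULL_MASK_DEMO are invented names, and HIP_DIFF stands in for the macro that setup.py defines on ROCm builds.

// Standalone sketch (not from the commit) of the HIP/CUDA shuffle split.
// Compile with nvcc; on a ROCm build the HIP_DIFF branch would be taken.
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK_DEMO 0xffffffffu

__device__ __forceinline__ float warpReduceSumDemo(float val) {
  // Shuffle distances 16, 8, 4, 2, 1 fold 32 lanes down to lane 0.
  for (int offset = 16; offset > 0; offset /= 2)
#ifdef HIP_DIFF
    val += __shfl_down(val, offset);  // HIP: no mask argument
#else
    val += __shfl_down_sync(FULL_MASK_DEMO, val, offset);  // CUDA 9+
#endif
  return val;
}

__global__ void sum32(const float *in, float *out) {
  float v = warpReduceSumDemo(in[threadIdx.x]);
  if (threadIdx.x == 0) *out = v;  // lane 0 holds the warp total
}

int main() {
  float h[32], result = 0.f, *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h[i] = 1.f;
  cudaMalloc(&d_in, sizeof(h));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h, sizeof(h), cudaMemcpyHostToDevice);
  sum32<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %.1f\n", result);  // expected: 32.0
  return 0;
}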
mmcv/ops/csrc/pytorch/info.cpp
@@ -3,12 +3,15 @@
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#include <cuda_runtime_api.h>
int get_cudart_version() { return CUDART_VERSION; }
#endif
#endif

std::string get_compiling_cuda_version() {
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
  std::ostringstream oss;
  // copied from
  // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231

@@ -20,6 +23,9 @@ std::string get_compiling_cuda_version() {
  };
  printCudaStyleVersion(get_cudart_version());
  return oss.str();
#else
  return std::string("rocm not available");
#endif
#else
  return std::string("not available");
#endif
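For reference, CUDART_VERSION packs the runtime version as major * 1000 + minor * 10 (CUDA 10.2 is 10020), and the elided printCudaStyleVersion lambda, copied from the linked PyTorch CUDAHooks.cpp, presumably unpacks it the way this standalone sketch of ours does; format_cuda_version is a hypothetical name, not part of the commit.

// Our own sketch of CUDA-style version formatting; not code from the commit.
#include <iostream>
#include <sstream>
#include <string>

// CUDART_VERSION encodes e.g. CUDA 10.2 as 10020: major * 1000 + minor * 10.
std::string format_cuda_version(int v) {
  std::ostringstream oss;
  oss << v / 1000 << "." << v / 10 % 100;
  if (v % 10 != 0) oss << "." << v % 10;  // occasional patch digit
  return oss.str();
}

int main() {
  std::cout << format_cuda_version(10020) << "\n";  // prints 10.2
  std::cout << format_cuda_version(11080) << "\n";  // prints 11.8
  return 0;
}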
setup.py
@@ -3,6 +3,9 @@ import os
import re
from pkg_resources import DistributionNotFound, get_distribution
from setuptools import find_packages, setup
import subprocess
from typing import Optional, Union
from pathlib import Path

EXT_TYPE = ''
try:
@@ -30,8 +33,30 @@ def choose_requirement(primary, secondary):
    return str(primary)


def get_sha(pytorch_root: Union[str, Path]) -> str:
    try:
        return subprocess.check_output(
            ['git', 'rev-parse', 'HEAD'],
            cwd=pytorch_root).decode('ascii').strip()
    except Exception:
        return 'Unknown'


def get_version_add(sha: Optional[str] = None) -> str:
    mmcv_root = os.path.dirname(os.path.abspath(__file__))
    add_version_path = os.path.join(os.path.join(mmcv_root, "mmcv"), "version.py")
    if sha != 'Unknown':
        if sha is None:
            sha = get_sha(mmcv_root)
        version = 'git' + sha[:7]
    if os.getenv('MMCV_BUILD_VERSION'):
        version_dtk = os.getenv('MMCV_BUILD_VERSION', "")
        version += "." + version_dtk
    with open(add_version_path, encoding="utf-8", mode="a") as file:
        file.write("__version__=__version__+'+{}'\n".format(version))
        file.close()


def get_version():
    get_version_add()
    version_file = 'mmcv/version.py'
    with open(version_file, 'r', encoding='utf-8') as f:
        exec(compile(f.read(), version_file, 'exec'))
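Tracing the new version helpers: get_sha asks git for HEAD, get_version_add builds a local-version label from the short sha plus MMCV_BUILD_VERSION and appends a line to mmcv/version.py, which get_version then exec's. A small trace of our own follows; the full sha below is invented, and the dtk string mirrors this commit's message.

# Standalone trace of what get_version_add() appends; not part of the commit.
sha = '27432c85' + '0' * 32        # invented full sha for illustration
build = 'dtk2210.1'                # value of MMCV_BUILD_VERSION

version = 'git' + sha[:7]          # -> 'git27432c8'
if build:
    version += '.' + build         # -> 'git27432c8.dtk2210.1'

line = "__version__=__version__+'+{}'\n".format(version)
print(line, end='')
# __version__=__version__+'+git27432c8.dtk2210.1'
# After exec'ing version.py, a base version such as '1.3.9' becomes
# '1.3.9+git27432c8.dtk2210.1' (a PEP 440 local version label).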
@@ -220,7 +245,19 @@ def get_extensions():
    define_macros = []
    extra_compile_args = {'cxx': []}

    is_rocm_pytorch = False
    try:
        from torch.utils.cpp_extension import ROCM_HOME
        is_rocm_pytorch = True if ((torch.version.hip is not None) and
                                   (ROCM_HOME is not None)) else False
    except ImportError:
        pass

    # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
    if is_rocm_pytorch or torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        if is_rocm_pytorch:
            define_macros += [('HIP_DIFF', None)]
        define_macros += [('MMCV_WITH_CUDA', None)]
        cuda_args = os.getenv('MMCV_CUDA_ARGS')
        extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
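The rewritten guard builds the extensions whenever either a ROCm PyTorch or CUDA is present, and adds HIP_DIFF only on ROCm so the C++/CUDA sources can branch as shown above. The detection itself can be checked interactively; this snippet is our own, not part of the commit.

# Our own standalone check of the ROCm-vs-CUDA detection used above.
import torch

is_rocm_pytorch = False
try:
    from torch.utils.cpp_extension import ROCM_HOME
    # On ROCm builds torch.version.hip is a string (e.g. '4.0.0');
    # on CUDA builds it is None.
    is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)
except ImportError:
    pass

print('torch.version.hip :', torch.version.hip)
print('rocm pytorch      :', is_rocm_pytorch)
print('cuda available    :', torch.cuda.is_available())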