OpenDAS / MMCV

Commit bd635011 (parent 9ba29737), authored Apr 26, 2023 by xiabo

ROCm environment adaptation

Showing 3 changed files with 63 additions and 5 deletions:
mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu      +5   −0
mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu     +53  −5
mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu   +5   −0
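
All three files receive the same kind of change: each CUDA-runtime-specific call is bracketed by an #ifndef MMCV_WITH_HIP / #else / #endif guard, with the HIP equivalent placed in the #else branch so the file also compiles for ROCm. A minimal sketch of the pattern, assuming MMCV_WITH_HIP is defined only when building for ROCm (the launch_guarded helper below is hypothetical, not code from this commit):

#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
#include <ATen/cuda/Exceptions.h>   // AT_CUDA_CHECK

// Hypothetical helper illustrating the guard used throughout the diff:
// launch a kernel on the current PyTorch stream via the CUDA runtime on
// NVIDIA builds and via the HIP runtime on ROCm builds.
static void launch_guarded(const void *kernel, dim3 grid, dim3 block,
                           void **args, size_t sharedBytes) {
#ifndef MMCV_WITH_HIP
  AT_CUDA_CHECK(cudaLaunchKernel(kernel, grid, block, args, sharedBytes,
                                 at::cuda::getCurrentCUDAStream()));
#else
  AT_CUDA_CHECK(hipLaunchKernel(kernel, grid, block, args, sharedBytes,
                                at::cuda::getCurrentCUDAStream()));
#endif
}

The commit inlines this guard at every launch and shuffle site rather than introducing such a helper.
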
mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu
@@ -289,7 +289,12 @@ torch::Tensor bias_act_op(const torch::Tensor &x, const torch::Tensor &b,
  int blockSize = 4 * 32;
  int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
  void *args[] = {&p};
#ifndef MMCV_WITH_HIP
  AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0,
                                 at::cuda::getCurrentCUDAStream()));
#else
  AT_CUDA_CHECK(hipLaunchKernel(kernel, gridSize, blockSize, args, 0,
                                at::cuda::getCurrentCUDAStream()));
#endif
  return y;
}
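
For reference, the unchanged arithmetic above sizes the grid with an integer ceiling division, so all p.sizeX elements are covered when each block handles p.loopX * blockSize of them. A worked example with illustrative numbers (not values taken from the code):

// (n - 1) / d + 1 is integer ceiling division: ceil(n / d) for n > 0.
int blockSize = 4 * 32;                                // 128 threads per block
int loopX = 4;                                         // example value for p.loopX
int sizeX = 1000000;                                   // example value for p.sizeX
int gridSize = (sizeX - 1) / (loopX * blockSize) + 1;  // ceil(1000000 / 512) == 1954
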
mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu
@@ -672,8 +672,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
        // Combine signs.
        uint32_t s = sx + sy + sw + sz;
        s <<= (signX & 3) << 1;
#ifndef MMCV_WITH_HIP
        s |= __shfl_xor_sync(groupMask, s, 1);
        s |= __shfl_xor_sync(groupMask, s, 2);
#else
        s |= __shfl_xor(s, 1);
        s |= __shfl_xor(s, 2);
#endif
        // Write signs.
        if ((uint32_t)(signY + 0) < sShapeMaxY) {
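
The device-side hunks replace __shfl_xor_sync with HIP's __shfl_xor. On CUDA the intrinsic takes an explicit lane-membership mask (groupMask) and synchronizes those lanes; the HIP intrinsic has no mask parameter and operates across the wavefront. Because the two butterfly steps use lane masks 1 and 2, each lane only exchanges data within its aligned group of four lanes on either backend, so the OR-combined sign word comes out the same. A sketch of the same guard factored into a helper (hypothetical, not code from this commit):

// Hypothetical wrapper around the xor-shuffle used above; groupMask is only
// meaningful on the CUDA path, where the intrinsic requires a membership mask.
__device__ __forceinline__ uint32_t xor_shuffle(uint32_t groupMask, uint32_t v,
                                                int laneMask) {
#ifndef MMCV_WITH_HIP
  return __shfl_xor_sync(groupMask, v, laneMask);
#else
  (void)groupMask;  // HIP's __shfl_xor takes no membership mask.
  return __shfl_xor(v, laneMask);
#endif
}

With such a helper the two lines above would read s |= xor_shuffle(groupMask, s, 1); s |= xor_shuffle(groupMask, s, 2); the commit instead repeats the #ifndef guard at each call site.
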
@@ -720,9 +725,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
        // Combine signs.
        uint32_t s = sx + sy + sw + sz;
        s <<= (signX & 3) << 1;
#ifndef MMCV_WITH_HIP
        s |= __shfl_xor_sync(groupMask, s, 1);
        s |= __shfl_xor_sync(groupMask, s, 2);
#else
        s |= __shfl_xor(s, 1);
        s |= __shfl_xor(s, 2);
#endif
        // Write signs.
        if ((uint32_t)(signY + 0) < sShapeMaxY) {
          p.s[si0] = (unsigned char)(s >> 0);
@@ -852,9 +861,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
        // Combine signs.
        int s = sx + sy;
        s <<= signXo;
#ifndef MMCV_WITH_HIP
        s |= __shfl_xor_sync(groupMask, s, 1);
        s |= __shfl_xor_sync(groupMask, s, 2);
#else
        s |= __shfl_xor(s, 1);
        s |= __shfl_xor(s, 2);
#endif
        // Write signs.
        if ((uint32_t)(signY + 0) < sShapeMaxY) {
          p.s[si0] = (unsigned char)(s >> 0);
@@ -882,9 +895,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
        // Combine signs.
        int s = sx + sy;
        s <<= signXo;
#ifndef MMCV_WITH_HIP
        s |= __shfl_xor_sync(groupMask, s, 1);
        s |= __shfl_xor_sync(groupMask, s, 2);
#else
        s |= __shfl_xor(s, 1);
        s |= __shfl_xor(s, 2);
#endif
        // Write signs.
        if ((uint32_t)(signY + 0) < sShapeMaxY) {
          p.s[si0] = (unsigned char)(s >> 0);
@@ -1171,8 +1188,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
        }
        if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y &&
            signY >= minY) {
#ifndef MMCV_WITH_HIP
          s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.
          s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.
#else
          s += __shfl_xor(s, 1);  // Coalesce.
          s += __shfl_xor(s, 2);  // Coalesce.
#endif
          p.s[si] = s;  // Write.
        }
      } else {
@@ -1189,8 +1211,13 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
            s = signXbit * 2;
            v = InternalType<T>::clamp(v, p.clamp);
          }
#ifndef MMCV_WITH_HIP
          s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.
          s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.
#else
          s += __shfl_xor(s, 1);  // Coalesce.
          s += __shfl_xor(s, 2);  // Coalesce.
#endif
          p.s[si] = s;  // Write.
        } else {
          // Just compute the value.
@@ -1411,11 +1438,17 @@ static __global__ void filtered_lrelu_act_kernel(
      // Coalesce into threads 0 and 16 of warp.
      uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
      s <<= ((threadIdx.x & 15) << 1);  // Shift into place.
#ifndef MMCV_WITH_HIP
      s |= __shfl_xor_sync(m, s, 1);  // Distribute.
      s |= __shfl_xor_sync(m, s, 2);
      s |= __shfl_xor_sync(m, s, 4);
      s |= __shfl_xor_sync(m, s, 8);
#else
      s |= __shfl_xor(s, 1);  // Distribute.
      s |= __shfl_xor(s, 2);
      s |= __shfl_xor(s, 4);
      s |= __shfl_xor(s, 8);
#endif
      // Write signs if leader and in p.s.
      if (!(threadIdx.x & 15) && x < p.sShape.x)  // y is always in.
      {
@@ -1839,9 +1872,13 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
  }
  // Launch filter setup kernel.
#ifndef MMCV_WITH_HIP
  AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0,
                                 at::cuda::getCurrentCUDAStream()));
#else
  AT_CUDA_CHECK(hipLaunchKernel(spec.setup, 1, 1024, args, 0,
                                at::cuda::getCurrentCUDAStream()));
#endif
  // Copy kernels to constant memory.
  if (writeSigns && !readSigns)
    AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));
@@ -1866,9 +1903,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
  {
    p.blockZofs = zofs;
    int subGz = std::min(maxSubGz, gz - zofs);
#ifndef MMCV_WITH_HIP
    AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
                                   spec.dynamicSharedKB << 10,
                                   at::cuda::getCurrentCUDAStream()));
#else
    AT_CUDA_CHECK(hipLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
                                  spec.dynamicSharedKB << 10,
                                  at::cuda::getCurrentCUDAStream()));
#endif
  }
  // Done.
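
Note that the fifth argument of both cudaLaunchKernel and hipLaunchKernel is the dynamic shared memory size in bytes, which is why spec.dynamicSharedKB is scaled by a left shift of 10 (that is, multiplied by 1024). A small illustration with an assumed value:

int dynamicSharedKB = 96;                            // example value only
size_t sharedBytes = size_t(dynamicSharedKB) << 10;  // 96 KiB == 98304 bytes

On the CUDA side, requests above the default 48 KiB per block additionally require opting in via cudaFuncAttributeMaxDynamicSharedMemorySize; that step is not shown in this hunk.
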
@@ -1983,7 +2026,12 @@ torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
  gz = std::min(gz, gmax);
  // Launch.
#ifndef MMCV_WITH_HIP
  AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
                                 at::cuda::getCurrentCUDAStream()));
#else
  AT_CUDA_CHECK(hipLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
                                at::cuda::getCurrentCUDAStream()));
#endif
  return so;
}
mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu
@@ -734,7 +734,12 @@ torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
  // Launch CUDA kernel.
  void *args[] = {&p};
#ifndef MMCV_WITH_HIP
  AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
                                 at::cuda::getCurrentCUDAStream()));
#else
  AT_CUDA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
                                at::cuda::getCurrentCUDAStream()));
#endif
  return y;
}
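
Nothing in the diff defines MMCV_WITH_HIP itself; presumably the extension build supplies it (for example as a -DMMCV_WITH_HIP compile definition) only when the sources are compiled for ROCm, so a CUDA build still sees exactly the original cudaLaunchKernel and __shfl_xor_sync branches. A trivial sketch of how the guard resolves per backend (illustrative, not MMCV code):

#ifndef MMCV_WITH_HIP
constexpr const char *kGpuBackend = "CUDA";  // nvcc build: cudaLaunchKernel, __shfl_xor_sync
#else
constexpr const char *kGpuBackend = "ROCm";  // hipcc build: hipLaunchKernel, __shfl_xor
#endif
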