OpenDAS / MMCV · Commit 0fc19b46 (unverified)

[Feature]: parrots add parrots/fused_bias & upfirdn2d (#989)

Authored Apr 27, 2021 by Wang Xiaolin; committed by GitHub on Apr 27, 2021
Parent: 1738d50c
Showing 11 changed files with 721 additions and 50 deletions (+721, -50):
mmcv/ops/csrc/parrots/fused_bias_leakyrelu.cpp       +26   -0
mmcv/ops/csrc/parrots/fused_bias_leakyrelu_cuda.cu   +109  -0
mmcv/ops/csrc/parrots/fused_bias_parrots.cpp         +40   -0
mmcv/ops/csrc/parrots/upfirdn2d.cpp                  +25   -0
mmcv/ops/csrc/parrots/upfirdn2d_kernel.cu            +370  -0
mmcv/ops/csrc/parrots/upfirdn2d_parrots.cpp          +46   -0
mmcv/ops/fused_bias_leakyrelu.py                     +24   -7
mmcv/ops/upfirdn2d.py                                +27   -20
mmcv/utils/ext_loader.py                             +1    -1
tests/test_ops/test_fused_bias_leakyrelu.py          +22   -7
tests/test_ops/test_upfirdn2d.py                     +31   -15
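Taken together, the new parrots sources mirror the existing PyTorch CUDA bindings, so the same Python-level ops stay importable on either backend. A minimal usage sketch of the two entry points these files back (shapes, kernel values, and device handling are illustrative, not taken from this diff):

# Sketch only: the ops wired up by this commit are exposed from mmcv.ops.
import torch
from mmcv.ops import FusedBiasLeakyReLU, upfirdn2d

x = torch.randn(2, 8, 16, 16, device='cuda')   # illustrative NCHW input
act = FusedBiasLeakyReLU(8).cuda()              # per-channel bias + leaky ReLU + rescale
y = act(x)

k = torch.ones(3, 3, device='cuda') / 9.0       # illustrative 3x3 FIR kernel
z = upfirdn2d(x, k, up=1, down=1, pad=(1, 1))   # upsample-filter-downsample; here a plain blur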
mmcv/ops/csrc/parrots/fused_bias_leakyrelu.cpp (new file, mode 100644)
// Modified from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor &input,
                                      const torch::Tensor &bias,
                                      const torch::Tensor &refer, int act,
                                      int grad, float alpha, float scale);
#endif

torch::Tensor fused_bias_leakyrelu(const torch::Tensor &input,
                                   const torch::Tensor &bias,
                                   const torch::Tensor &refer, int act,
                                   int grad, float alpha, float scale) {
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA(input);
  CHECK_CUDA(bias);
  return fused_bias_leakyrelu_op(input, bias, refer, act, grad, alpha, scale);
#else
  AT_ERROR("Fused bias leakyrelu is not compiled with GPU support");
#endif
}
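For readers unfamiliar with the op, the forward path this dispatcher guards (act=3, grad=0) is a bias add followed by a leaky ReLU and a constant rescale. A rough, non-authoritative PyTorch equivalent of that forward path, assuming an NCHW-style input and the defaults used elsewhere in mmcv (negative_slope=0.2, scale=2**0.5):

import torch
import torch.nn.functional as F

def fused_bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # bias is broadcast over the channel dimension, then leaky ReLU and rescale
    bias = bias.view(1, -1, *([1] * (x.dim() - 2)))
    return F.leaky_relu(x + bias, negative_slope) * scale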
mmcv/ops/csrc/parrots/fused_bias_leakyrelu_cuda.cu (new file, mode 100644)
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>

#include <ATen/cuda/CUDAApplyUtils.cuh>

template <typename scalar_t>
static __global__ void fused_bias_act_kernel(
    scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
    const scalar_t *p_ref, int act, int grad, scalar_t alpha, scalar_t scale,
    int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;
    // act = 1: linear layer
    // act = 3: leaky relu layer
    // grad = 0: direct forward path
    // grad = 1: first order deviation
    // grad = 2: second order deviation
    switch (act * 10 + grad) {
      default:
      case 10:
        y = x;
        break;
      case 11:
        y = x;
        break;
      case 12:
        y = 0.0;
        break;

      case 30:
        y = (x > 0.0) ? x : x * alpha;
        break;
      case 31:
        y = (ref > 0.0) ? x : x * alpha;
        break;
      case 32:
        y = 0.0;
        break;
    }

    out[xi] = y * scale;
  }
}

torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor &input,
                                      const torch::Tensor &bias,
                                      const torch::Tensor &refer, int act,
                                      int grad, float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
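The launch configuration above is worth a quick sanity check: each thread processes loop_x elements, blocks hold 4 * 32 = 128 threads, and the grid uses the usual ceiling division so the whole flattened tensor is covered. A small worked example (tensor shape chosen only for illustration):

# Worked example of the grid sizing used by fused_bias_leakyrelu_op.
loop_x, block_size = 4, 4 * 32           # 128 threads per block
size_x = 2 * 8 * 32 * 32                 # numel of an illustrative (2, 8, 32, 32) input = 16384
grid_size = (size_x - 1) // (loop_x * block_size) + 1
print(grid_size)                         # 32 blocks; 32 blocks * 128 threads * 4 loops = 16384 elements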
mmcv/ops/csrc/parrots/fused_bias_parrots.cpp (new file, mode 100644)
#include <torch/extension.h>

#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>

using namespace at;
using namespace parrots;

torch::Tensor fused_bias_leakyrelu(const torch::Tensor &input,
                                   const torch::Tensor &bias,
                                   const torch::Tensor &refer, int act,
                                   int grad, float alpha, float scale);

void fused_bias_leakyrelu_parrots(CudaContext &ctx, const SSElement &attr,
                                  const OperatorBase::in_list_t &ins,
                                  OperatorBase::out_list_t &outs) {
  int act, grad;
  float alpha, scale;
  SSAttrs(attr)
      .get<int>("act", act)
      .get<int>("grad", grad)
      .get<float>("alpha", alpha)
      .get<float>("scale", scale)
      .done();
  const auto &input = buildATensor(ctx, ins[0]);
  const auto &bias = buildATensor(ctx, ins[1]);
  const auto &refer = buildATensor(ctx, ins[2]);
  auto out = fused_bias_leakyrelu(input, bias, refer, act, grad, alpha, scale);
  updateDArray(ctx, out, outs[0]);
}

PARROTS_EXTENSION_REGISTER(fused_bias_leakyrelu)
    .attr("act")
    .attr("grad")
    .attr("alpha")
    .attr("scale")
    .input(3)
    .output(1)
    .apply(fused_bias_leakyrelu_parrots)
    .done();
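Under parrots the op is registered by name with scalar attributes (act, grad, alpha, scale) and three tensor inputs, which is why the Python wrappers later in this commit switch to keyword arguments. A hedged sketch of the resulting extension call, mirroring the forward path in mmcv/ops/fused_bias_leakyrelu.py below (tensor contents illustrative):

import torch
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['fused_bias_leakyrelu'])

x = torch.randn(2, 8, 4, 4, device='cuda')     # illustrative input
bias = torch.zeros(8, device='cuda')
empty = x.new_empty(0)                         # unused 'refer' tensor in the plain forward pass
out = ext_module.fused_bias_leakyrelu(
    x, bias, empty, act=3, grad=0, alpha=0.2, scale=2 ** 0.5)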
mmcv/ops/csrc/parrots/upfirdn2d.cpp (new file, mode 100644)
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);
#endif

torch::Tensor upfirdn2d(const torch::Tensor &input,
                        const torch::Tensor &kernel, int up_x, int up_y,
                        int down_x, int down_y, int pad_x0, int pad_x1,
                        int pad_y0, int pad_y1) {
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA(input);
  CHECK_CUDA(kernel);

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0,
                      pad_x1, pad_y0, pad_y1);
#else
  AT_ERROR("UpFirDn2d is not compiled with GPU support");
#endif
}
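The CUDA implementation declared here sizes its output with the same floor arithmetic per axis: out = (in * up + pad0 + pad1 - kernel + down) // down. A quick check of that formula in Python (numbers purely illustrative):

def upfirdn_out_size(in_size, up, down, pad0, pad1, kernel):
    # mirrors the out_h / out_w computation in upfirdn2d_kernel.cu below
    return (in_size * up + pad0 + pad1 - kernel + down) // down

print(upfirdn_out_size(8, up=2, down=1, pad0=1, pad1=1, kernel=3))   # 16
print(upfirdn_out_size(16, up=1, down=2, pad0=1, pad1=1, kernel=4))  # 8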
mmcv/ops/csrc/parrots/upfirdn2d_kernel.cu (new file, mode 100644)
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>

#include <ATen/cuda/CUDAApplyUtils.cuh>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();

      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
      case 1:
        upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);
        break;

      case 2:
        upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);
        break;

      case 3:
        upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);
        break;

      case 4:
        upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);
        break;

      case 5:
        upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);
        break;

      case 6:
        upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);
        break;

      default:
        upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
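The mode table above only specializes the configurations common in StyleGAN2 (up or down factors of 1 or 2 with small kernels); everything else falls back to upfirdn2d_kernel_large. A small Python sketch of the same selection cascade, handy for predicting which kernel a given call will hit (mode numbers match the code above; later matches win):

def select_mode(up_x, up_y, down_x, down_y, kernel_h, kernel_w):
    # mirrors the sequence of if-statements in upfirdn2d_op
    mode = -1
    if (up_x, up_y, down_x, down_y) == (1, 1, 1, 1) and kernel_h <= 4 and kernel_w <= 4:
        mode = 1
    if (up_x, up_y, down_x, down_y) == (1, 1, 1, 1) and kernel_h <= 3 and kernel_w <= 3:
        mode = 2
    if (up_x, up_y, down_x, down_y) == (2, 2, 1, 1) and kernel_h <= 4 and kernel_w <= 4:
        mode = 3
    if (up_x, up_y, down_x, down_y) == (2, 2, 1, 1) and kernel_h <= 2 and kernel_w <= 2:
        mode = 4
    if (up_x, up_y, down_x, down_y) == (1, 1, 2, 2) and kernel_h <= 4 and kernel_w <= 4:
        mode = 5
    if (up_x, up_y, down_x, down_y) == (1, 1, 2, 2) and kernel_h <= 2 and kernel_w <= 2:
        mode = 6
    return mode

print(select_mode(2, 2, 1, 1, 4, 4))   # 3: 2x upsample with a 4x4 kernel
print(select_mode(1, 1, 4, 4, 6, 6))   # -1: falls back to upfirdn2d_kernel_large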
mmcv/ops/csrc/parrots/upfirdn2d_parrots.cpp (new file, mode 100644)
#include <torch/extension.h>

#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>

using namespace at;
using namespace parrots;

torch::Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x,
                        int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1);

void upfirdn2d_parrots(CudaContext &ctx, const SSElement &attr,
                       const OperatorBase::in_list_t &ins,
                       OperatorBase::out_list_t &outs) {
  int up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1;
  const auto &input = buildATensor(ctx, ins[0]);
  const auto &kernel = buildATensor(ctx, ins[1]);
  SSAttrs(attr)
      .get("up_x", up_x)
      .get("up_y", up_y)
      .get("down_x", down_x)
      .get("down_y", down_y)
      .get("pad_x0", pad_x0)
      .get("pad_x1", pad_x1)
      .get("pad_y0", pad_y0)
      .get("pad_y1", pad_y1)
      .done();
  auto out = upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0,
                       pad_x1, pad_y0, pad_y1);
  updateDArray(ctx, out, outs[0]);
}

PARROTS_EXTENSION_REGISTER(upfirdn2d)
    .attr("up_x")
    .attr("up_y")
    .attr("down_x")
    .attr("down_y")
    .attr("pad_x0")
    .attr("pad_x1")
    .attr("pad_y0")
    .attr("pad_y1")
    .input(2)
    .output(1)
    .apply(upfirdn2d_parrots)
    .done();
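Because PARROTS_EXTENSION_REGISTER declares every up/down/pad value as a named attribute, the Python wrapper has to pass them as keyword arguments, which is exactly the change made to mmcv/ops/upfirdn2d.py below. A hedged sketch of the resulting extension call (tensor shapes and contents illustrative):

import torch
from mmcv.utils import ext_loader

upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])

x = torch.randn(1, 8, 8, 3, device='cuda')   # the op reads (major, in_h, in_w, minor)
k = torch.ones(3, 3, device='cuda') / 9.0
out = upfirdn2d_ext.upfirdn2d(
    x, k, up_x=2, up_y=2, down_x=1, down_y=1,
    pad_x0=1, pad_x1=1, pad_y0=1, pad_y1=1)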
mmcv/ops/fused_bias_leakyrelu.py
@@ -25,9 +25,14 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
         empty = grad_output.new_empty(0)

-        grad_input = ext_module.fused_bias_leakyrelu(grad_output, empty, out,
-                                                     3, 1, negative_slope,
-                                                     scale)
+        grad_input = ext_module.fused_bias_leakyrelu(
+            grad_output,
+            empty,
+            out,
+            act=3,
+            grad=1,
+            alpha=negative_slope,
+            scale=scale)

         dim = [0]

@@ -46,8 +51,13 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
         # the first part is zero. Thus, we direct consider the second part
         # which is similar with the first order deviation in implementation.
         gradgrad_out = ext_module.fused_bias_leakyrelu(
-            gradgrad_input, gradgrad_bias.to(out.dtype), out, 3, 1,
-            ctx.negative_slope, ctx.scale)
+            gradgrad_input,
+            gradgrad_bias,
+            out,
+            act=3,
+            grad=1,
+            alpha=ctx.negative_slope,
+            scale=ctx.scale)

         return gradgrad_out, None, None, None

@@ -57,8 +67,15 @@ class FusedBiasLeakyReLUFunction(Function):
     @staticmethod
     def forward(ctx, input, bias, negative_slope, scale):
         empty = input.new_empty(0)

-        out = ext_module.fused_bias_leakyrelu(input, bias, empty, 3, 0,
-                                              negative_slope, scale)
+        out = ext_module.fused_bias_leakyrelu(
+            input,
+            bias,
+            empty,
+            act=3,
+            grad=0,
+            alpha=negative_slope,
+            scale=scale)
         ctx.save_for_backward(out)
         ctx.negative_slope = negative_slope
         ctx.scale = scale
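At the module level nothing changes for users; the switch to keyword arguments only affects how the compiled extension is called, so the op stays usable under both backends. A usage sketch matching the test further below (tensor shape illustrative):

import torch
from mmcv.ops import FusedBiasLeakyReLU

m = FusedBiasLeakyReLU(2).cuda()                 # 2 channels, as in the test
x = torch.randn(2, 2, 4, 4, device='cuda', requires_grad=True)
y = m(x)                                         # leaky_relu(x + bias) * 2**0.5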
mmcv/ops/upfirdn2d.py
@@ -24,15 +24,14 @@ class UpFirDn2dBackward(Function):
         grad_input = upfirdn2d_ext.upfirdn2d(
             grad_output,
             grad_kernel,
-            down_x,
-            down_y,
-            up_x,
-            up_y,
-            g_pad_x0,
-            g_pad_x1,
-            g_pad_y0,
-            g_pad_y1,
-        )
+            up_x=down_x,
+            up_y=down_y,
+            down_x=up_x,
+            down_y=up_y,
+            pad_x0=g_pad_x0,
+            pad_x1=g_pad_x1,
+            pad_y0=g_pad_y0,
+            pad_y1=g_pad_y1)
         grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
                                      in_size[3])

@@ -63,15 +62,14 @@ class UpFirDn2dBackward(Function):
         gradgrad_out = upfirdn2d_ext.upfirdn2d(
             gradgrad_input,
             kernel,
-            ctx.up_x,
-            ctx.up_y,
-            ctx.down_x,
-            ctx.down_y,
-            ctx.pad_x0,
-            ctx.pad_x1,
-            ctx.pad_y0,
-            ctx.pad_y1,
-        )
+            up_x=ctx.up_x,
+            up_y=ctx.up_y,
+            down_x=ctx.down_x,
+            down_y=ctx.down_y,
+            pad_x0=ctx.pad_x0,
+            pad_x1=ctx.pad_x1,
+            pad_y0=ctx.pad_y0,
+            pad_y1=ctx.pad_y1)
         # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
         #                                  ctx.out_size[1], ctx.in_size[3])
         gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],

@@ -111,8 +109,17 @@ class UpFirDn2d(Function):
         ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

-        out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x,
-                                      down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+        out = upfirdn2d_ext.upfirdn2d(
+            input,
+            kernel,
+            up_x=up_x,
+            up_y=up_y,
+            down_x=down_x,
+            down_y=down_y,
+            pad_x0=pad_x0,
+            pad_x1=pad_x1,
+            pad_y0=pad_y0,
+            pad_y1=pad_y1)
         # out = out.view(major, out_h, out_w, minor)
         out = out.view(-1, channel, out_h, out_w)
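The functional form exercised by the tests takes the up factor, down factor, and padding directly; assuming the wrapper signature upfirdn2d(input, kernel, up, down, pad) that those tests pass positionally, a usage sketch looks like this (values illustrative):

import torch
from mmcv.ops import upfirdn2d

x = torch.randn(1, 3, 8, 8, device='cuda')
k = torch.ones(3, 3, device='cuda') / 9.0       # illustrative FIR (blur) kernel
y = upfirdn2d(x, k, up=2, down=1, pad=(1, 1))   # (1, 3, 16, 16) by the out-size formula above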
mmcv/utils/ext_loader.py
@@ -19,7 +19,7 @@ else:
         'nms', 'softnms', 'nms_match', 'nms_rotated', 'top_pool_forward',
         'top_pool_backward', 'bottom_pool_forward', 'bottom_pool_backward',
         'left_pool_forward', 'left_pool_backward', 'right_pool_forward',
-        'right_pool_backward'
+        'right_pool_backward', 'fused_bias_leakyrelu', 'upfirdn2d'
     ]

     def load_ext(name, funcs):
tests/test_ops/test_fused_bias_leakyrelu.py
 import pytest
 import torch
-from torch.autograd import gradcheck, gradgradcheck
+
+_USING_PARROTS = True
+try:
+    from parrots.autograd import gradcheck
+except ImportError:
+    from torch.autograd import gradcheck, gradgradcheck
+    _USING_PARROTS = False


 class TestFusedBiasLeakyReLU(object):

@@ -16,13 +22,22 @@ class TestFusedBiasLeakyReLU(object):
     def test_gradient(self):

         from mmcv.ops import FusedBiasLeakyReLU
-        gradcheck(
-            FusedBiasLeakyReLU(2).cuda(),
-            self.input_tensor,
-            eps=1e-4,
-            atol=1e-3)
+        if _USING_PARROTS:
+            gradcheck(
+                FusedBiasLeakyReLU(2).cuda(),
+                self.input_tensor,
+                delta=1e-4,
+                pt_atol=1e-3)
+        else:
+            gradcheck(
+                FusedBiasLeakyReLU(2).cuda(),
+                self.input_tensor,
+                eps=1e-4,
+                atol=1e-3)

-    @pytest.mark.skipif(
-        not torch.cuda.is_available(), reason='requires cuda')
+    @pytest.mark.skipif(
+        not torch.cuda.is_available() or _USING_PARROTS,
+        reason='requires cuda')
     def test_gradgradient(self):

         from mmcv.ops import FusedBiasLeakyReLU
tests/test_ops/test_upfirdn2d.py
 import pytest
 import torch
-from torch.autograd import gradcheck, gradgradcheck
+
+_USING_PARROTS = True
+try:
+    from parrots.autograd import gradcheck
+except ImportError:
+    from torch.autograd import gradcheck, gradgradcheck
+    _USING_PARROTS = False


 class TestUpFirDn2d(object):

@@ -25,17 +31,27 @@ class TestUpFirDn2d(object):
     @pytest.mark.skipif(
         not torch.cuda.is_available(), reason='requires cuda')
     def test_upfirdn2d(self):
         from mmcv.ops import upfirdn2d
-        gradcheck(
-            upfirdn2d,
-            (self.input_tensor.cuda(), self.kernel.type_as(
-                self.input_tensor).cuda(), self.factor, 1, self.pad),
-            eps=1e-4,
-            atol=1e-3)
-
-        gradgradcheck(
-            upfirdn2d,
-            (self.input_tensor.cuda(), self.kernel.type_as(
-                self.input_tensor).cuda(), self.factor, 1, self.pad),
-            eps=1e-4,
-            atol=1e-3)
+        if _USING_PARROTS:
+            gradcheck(
+                upfirdn2d,
+                (self.input_tensor.cuda(), self.kernel.type_as(
+                    self.input_tensor).cuda(), self.factor, 1, self.pad),
+                delta=1e-4,
+                pt_atol=1e-3)
+        else:
+            gradcheck(
+                upfirdn2d,
+                (self.input_tensor.cuda(), self.kernel.type_as(
+                    self.input_tensor).cuda(), self.factor, 1, self.pad),
+                eps=1e-4,
+                atol=1e-3)
+
+            gradgradcheck(
+                upfirdn2d,
+                (self.input_tensor.cuda(), self.kernel.type_as(
+                    self.input_tensor).cuda(), self.factor, 1, self.pad),
+                eps=1e-4,
+                atol=1e-3)