Commit 310493b2 authored by mashun1
stylegan3
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <torch/extension.h>
#include <ATen/cuda/HIPContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "filtered_lrelu.h"
#include <iostream>
//------------------------------------------------------------------------
static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
{
// Set CUDA device.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
// Validate arguments.
TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
TORCH_CHECK(x.numel() > 0, "x is empty");
TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
TORCH_CHECK(fu.numel() > 0, "fu is empty");
TORCH_CHECK(fd.numel() > 0, "fd is empty");
TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
// Figure out how much shared memory is available on the device.
int maxSharedBytes = 0;
AT_CUDA_CHECK(hipDeviceGetAttribute(&maxSharedBytes, hipDeviceAttributeMaxSharedMemoryPerBlock, x.device().index()));
int sharedKB = maxSharedBytes >> 10;
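// Note: this HIP path queries hipDeviceAttributeMaxSharedMemoryPerBlock, while the CUDA variant below uses
// cudaDevAttrMaxSharedMemoryPerBlockOptin, which can report the larger opt-in limit needed by kernels that
// request extra dynamic shared memory via spec.dynamicSharedKB.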
// Populate enough launch parameters to check if a CUDA kernel exists.
filtered_lrelu_kernel_params p;
p.up = up;
p.down = down;
p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
if (!test_spec.exec)
{
// No kernel found - return empty tensors and indicate missing kernel with return code of -1.
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
}
// Input/output element size.
int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
// Input sizes.
int64_t xw = (int)x.size(3);
int64_t xh = (int)x.size(2);
int64_t fut_w = (int)fu.size(-1) - 1;
int64_t fut_h = (int)fu.size(0) - 1;
int64_t fdt_w = (int)fd.size(-1) - 1;
int64_t fdt_h = (int)fd.size(0) - 1;
// Logical size of upsampled buffer.
int64_t cw = xw * up + (px0 + px1) - fut_w;
int64_t ch = xh * up + (py0 + py1) - fut_h;
TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
// Compute output size and allocate.
int64_t yw = (cw - fdt_w + (down - 1)) / down;
int64_t yh = (ch - fdt_h + (down - 1)) / down;
TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
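// Illustrative example (numbers not from the original code): for a 64x64 input with up = down = 2,
// 12-tap separable filters (fut_w = fdt_w = 11) and padding px0 = 10, px1 = 11, the upsampled buffer is
// cw = 64*2 + 21 - 11 = 138 wide and the output is yw = (138 - 11 + 1) / 2 = 64, preserving the spatial size.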
torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
// Allocate sign tensor.
torch::Tensor so;
torch::Tensor s = si;
bool readSigns = !!s.numel();
int64_t sw_active = 0; // Active width of sign tensor.
if (writeSigns)
{
sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
}
else if (readSigns)
sw_active = s.size(3) << 2;
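// Signs are stored as 2 bits per element, i.e. 4 elements per byte, so the last dimension holds sw >> 2 bytes
// and reading converts back with s.size(3) << 2. Continuing the example above: sw_active = 64*2 - 1 + 11 = 138,
// rounded up to sw = 144 elements and stored as 36 bytes per row.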
// Validate sign tensor if in use.
if (readSigns || writeSigns)
{
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
}
// Populate rest of CUDA kernel parameters.
p.x = x.data_ptr();
p.y = y.data_ptr();
p.b = b.data_ptr();
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
p.fu = fu.data_ptr<float>();
p.fd = fd.data_ptr<float>();
p.pad0 = make_int2(px0, py0);
p.gain = gain;
p.slope = slope;
p.clamp = clamp;
p.flip = (flip_filters) ? 1 : 0;
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
p.sOfs = make_int2(sx, sy);
p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
// x, y, b strides are in bytes.
p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
p.bStride = sz * b.stride(0);
// fu, fd strides are in elements.
p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
// Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
bool index64b = false;
if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
if (s.numel() > INT_MAX) index64b = true;
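// The min/max sums above bound the most negative and most positive byte offsets reachable when indexing x and y
// with the given strides; 64-bit indexing is selected whenever either bound leaves the int32 range or the sign
// tensor has more than INT_MAX elements.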
// Choose CUDA kernel.
filtered_lrelu_kernel_spec spec = { 0 };
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
{
if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
{
// Choose kernel based on index type, datatype and sign read/write modes.
if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
}
});
TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found"); // This should not happen because we tested earlier that kernel exists.
// Launch CUDA kernel.
void* args[] = {&p};
int bx = spec.numWarps * 32;
int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
int gz = p.yShape.z * p.yShape.w;
// Repeat multiple horizontal tiles in a CTA?
if (spec.xrep)
{
p.tilesXrep = spec.xrep;
p.tilesXdim = gx;
gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
std::swap(gx, gy);
}
else
{
p.tilesXrep = 0;
p.tilesXdim = 0;
}
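// When spec.xrep is nonzero, each CTA loops over up to spec.xrep horizontal tiles (see the tile loop in the
// kernel), so the horizontal tile count is folded into fewer blocks and the grid x/y dimensions are swapped.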
// Launch filter setup kernel.
AT_CUDA_CHECK(hipLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
// Copy kernels to constant memory.
if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
// Set cache and shared memory configurations for main kernel.
AT_CUDA_CHECK(hipFuncSetCacheConfig(reinterpret_cast<const void*>(spec.exec), hipFuncCachePreferShared));
if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
AT_CUDA_CHECK(hipFuncSetAttribute(spec.exec, hipFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
AT_CUDA_CHECK(hipFuncSetSharedMemConfig(spec.exec, hipSharedMemBankSizeFourByte));
// Launch main kernel.
const int maxSubGz = 65535; // CUDA maximum for block z dimension.
for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
{
p.blockZofs = zofs;
int subGz = std::min(maxSubGz, gz - zofs);
AT_CUDA_CHECK(hipLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
}
// Done.
return std::make_tuple(y, so, 0);
}
//------------------------------------------------------------------------
static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
{
// Set CUDA device.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
// Validate arguments.
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
TORCH_CHECK(x.numel() > 0, "x is empty");
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
// Output signs if we don't have sign input.
torch::Tensor so;
torch::Tensor s = si;
bool readSigns = !!s.numel();
if (writeSigns)
{
int64_t sw = x.size(3);
sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
}
// Validate sign tensor if in use.
if (readSigns || writeSigns)
{
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
}
// Initialize CUDA kernel parameters.
filtered_lrelu_act_kernel_params p;
p.x = x.data_ptr();
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
p.gain = gain;
p.slope = slope;
p.clamp = clamp;
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
p.sOfs = make_int2(sx, sy);
// Choose CUDA kernel.
void* func = 0;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
{
if (writeSigns)
func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
else if (readSigns)
func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
else
func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
});
TORCH_CHECK(func, "internal error - CUDA kernel not found");
// Launch CUDA kernel.
void* args[] = {&p};
int bx = 128; // 4 warps per block.
// Logical size of launch = writeSigns ? p.s : p.x
uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
gx = (gx - 1) / bx + 1;
// Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
const uint32_t gmax = 65535;
gy = std::min(gy, gmax);
gz = std::min(gz, gmax);
// Launch.
AT_CUDA_CHECK(hipLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
return so;
}
//------------------------------------------------------------------------
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
}
//------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "filtered_lrelu.h"
//------------------------------------------------------------------------
static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
{
// Set CUDA device.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
// Validate arguments.
TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
TORCH_CHECK(x.numel() > 0, "x is empty");
TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
TORCH_CHECK(fu.numel() > 0, "fu is empty");
TORCH_CHECK(fd.numel() > 0, "fd is empty");
TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
// Figure out how much shared memory is available on the device.
int maxSharedBytes = 0;
AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
int sharedKB = maxSharedBytes >> 10;
// Populate enough launch parameters to check if a CUDA kernel exists.
filtered_lrelu_kernel_params p;
p.up = up;
p.down = down;
p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
if (!test_spec.exec)
{
// No kernel found - return empty tensors and indicate missing kernel with return code of -1.
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
}
// Input/output element size.
int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
// Input sizes.
int64_t xw = (int)x.size(3);
int64_t xh = (int)x.size(2);
int64_t fut_w = (int)fu.size(-1) - 1;
int64_t fut_h = (int)fu.size(0) - 1;
int64_t fdt_w = (int)fd.size(-1) - 1;
int64_t fdt_h = (int)fd.size(0) - 1;
// Logical size of upsampled buffer.
int64_t cw = xw * up + (px0 + px1) - fut_w;
int64_t ch = xh * up + (py0 + py1) - fut_h;
TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
// Compute output size and allocate.
int64_t yw = (cw - fdt_w + (down - 1)) / down;
int64_t yh = (ch - fdt_h + (down - 1)) / down;
TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
// Allocate sign tensor.
torch::Tensor so;
torch::Tensor s = si;
bool readSigns = !!s.numel();
int64_t sw_active = 0; // Active width of sign tensor.
if (writeSigns)
{
sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
}
else if (readSigns)
sw_active = s.size(3) << 2;
// Validate sign tensor if in use.
if (readSigns || writeSigns)
{
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
}
// Populate rest of CUDA kernel parameters.
p.x = x.data_ptr();
p.y = y.data_ptr();
p.b = b.data_ptr();
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
p.fu = fu.data_ptr<float>();
p.fd = fd.data_ptr<float>();
p.pad0 = make_int2(px0, py0);
p.gain = gain;
p.slope = slope;
p.clamp = clamp;
p.flip = (flip_filters) ? 1 : 0;
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
p.sOfs = make_int2(sx, sy);
p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
// x, y, b strides are in bytes.
p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
p.bStride = sz * b.stride(0);
// fu, fd strides are in elements.
p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
// Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
bool index64b = false;
if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
if (s.numel() > INT_MAX) index64b = true;
// Choose CUDA kernel.
filtered_lrelu_kernel_spec spec = { 0 };
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
{
if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
{
// Choose kernel based on index type, datatype and sign read/write modes.
if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
}
});
TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found"); // This should not happen because we tested earlier that kernel exists.
// Launch CUDA kernel.
void* args[] = {&p};
int bx = spec.numWarps * 32;
int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
int gz = p.yShape.z * p.yShape.w;
// Repeat multiple horizontal tiles in a CTA?
if (spec.xrep)
{
p.tilesXrep = spec.xrep;
p.tilesXdim = gx;
gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
std::swap(gx, gy);
}
else
{
p.tilesXrep = 0;
p.tilesXdim = 0;
}
// Launch filter setup kernel.
AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
// Copy kernels to constant memory.
if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
// Set cache and shared memory configurations for main kernel.
AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
// Launch main kernel.
const int maxSubGz = 65535; // CUDA maximum for block z dimension.
for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
{
p.blockZofs = zofs;
int subGz = std::min(maxSubGz, gz - zofs);
AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
}
// Done.
return std::make_tuple(y, so, 0);
}
//------------------------------------------------------------------------
static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
{
// Set CUDA device.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
// Validate arguments.
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
TORCH_CHECK(x.numel() > 0, "x is empty");
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
// Output signs if we don't have sign input.
torch::Tensor so;
torch::Tensor s = si;
bool readSigns = !!s.numel();
if (writeSigns)
{
int64_t sw = x.size(3);
sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
}
// Validate sign tensor if in use.
if (readSigns || writeSigns)
{
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
}
// Initialize CUDA kernel parameters.
filtered_lrelu_act_kernel_params p;
p.x = x.data_ptr();
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
p.gain = gain;
p.slope = slope;
p.clamp = clamp;
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
p.sOfs = make_int2(sx, sy);
// Choose CUDA kernel.
void* func = 0;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
{
if (writeSigns)
func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
else if (readSigns)
func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
else
func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
});
TORCH_CHECK(func, "internal error - CUDA kernel not found");
// Launch CUDA kernel.
void* args[] = {&p};
int bx = 128; // 4 warps per block.
// Logical size of launch = writeSigns ? p.s : p.x
uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
gx = (gx - 1) / bx + 1;
// Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
const uint32_t gmax = 65535;
gy = std::min(gy, gmax);
gz = std::min(gz, gmax);
// Launch.
AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
return so;
}
//------------------------------------------------------------------------
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
}
//------------------------------------------------------------------------
#include "hip/hip_runtime.h"
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include "filtered_lrelu.h"
#include <cstdint>
//------------------------------------------------------------------------
// Helpers.
enum // Filter modes.
{
MODE_SUSD = 0, // Separable upsampling, separable downsampling.
MODE_FUSD = 1, // Full upsampling, separable downsampling.
MODE_SUFD = 2, // Separable upsampling, full downsampling.
MODE_FUFD = 3, // Full upsampling, full downsampling.
};
template <class T> struct InternalType;
template <> struct InternalType<double>
{
typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
__device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
};
template <> struct InternalType<float>
{
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};
template <> struct InternalType<c10::Half>
{
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};
#define MIN(A, B) ((A) < (B) ? (A) : (B))
#define MAX(A, B) ((A) > (B) ? (A) : (B))
#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
((B)==2) ? ((int)((A)+1) >> 1) : \
((B)==4) ? ((int)((A)+3) >> 2) : \
(((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
{
if ((N & (N-1)) && N <= 256)
y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
else
y = i/N;
x = i - y*N;
}
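// Illustrative example (not part of the original source): for N = 6 and i = 25, (1<<24)/6 + 1 = 2796203,
// so y = (25 * 2796203) >> 24 = 4 and x = 25 - 4*6 = 1, matching 25/6 and 25%6 within the supported range.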
// Type cast stride before reading it.
template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
{
return *reinterpret_cast<const T*>(&x);
}
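// Reading the stride through the kernel's index type lets the int32_t variants do all addressing in 32-bit
// arithmetic; this is safe because the host code switches to int64_t indexing whenever offsets could exceed INT_MAX.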
//------------------------------------------------------------------------
// Filters, setup kernel, copying function.
#define MAX_FILTER_SIZE 32
// Combined up/down filter buffers so that transfer can be done with one copy.
__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
// Accessors to combined buffers to index up/down filters individually.
#define c_fu (c_fbuf)
#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
#define g_fu (g_fbuf)
#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
// Set up filters into global memory buffer.
static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
{
for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
{
int x, y;
fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
if (p.fuShape.y > 0)
g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
else
g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
if (p.fdShape.y > 0)
g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
else
g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
}
}
// Host function to copy filters written by setup kernel into constant buffer for main kernel.
//template <bool, bool> static hipError_t copy_filters(hipStream_t stream)
template <bool, bool> hipError_t copy_filters(hipStream_t stream)
{
void* src = 0;
hipError_t err = hipGetSymbolAddress(&src, HIP_SYMBOL(g_fbuf));
if (err) return err;
return hipMemcpyToSymbolAsync(HIP_SYMBOL(c_fbuf), src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, hipMemcpyDeviceToDevice, stream);
}
//------------------------------------------------------------------------
// Coordinate spaces:
// - Relative to input tensor: inX, inY, tileInX, tileInY
// - Relative to input tile: relInX, relInY, tileInW, tileInH
// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
// - Relative to output tensor: outX, outY, tileOutX, tileOutY
//
// Relationships between coordinate spaces:
// - inX = tileInX + relInX
// - inY = tileInY + relInY
// - relUpX = relInX * up + phaseInX
// - relUpY = relInY * up + phaseInY
// - relUpX = relOutX * down
// - relUpY = relOutY * down
// - outX = tileOutX + relOutX
// - outY = tileOutY + relOutY
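//
// Illustrative example (not part of the original source): with up = 2, down = 2 and phaseInX = 1,
// input sample relInX = 2 lands at relUpX = 2*2 + 1 = 5 in the upsampled tile, while output sample
// relOutX = 3 reads upsampled position relUpX = 3*2 = 6.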
extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
{
// Check that we don't try to support non-existing filter modes.
static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
static_assert(fuSize % up == 0, "upsampling filter size must be divisible by upsampling factor");
static_assert(fdSize % down == 0, "downsampling filter size must be divisible by downsampling factor");
static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
// Static definitions.
typedef typename InternalType<T>::scalar_t scalar_t;
typedef typename InternalType<T>::vec2_t vec2_t;
typedef typename InternalType<T>::vec4_t vec4_t;
const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
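// Illustrative example (tile sizes are template parameters supplied by choose_filtered_lrelu_kernel):
// with tileOutW = 64, tileOutH = 32, up = down = 2 and fuSize = fdSize = 8, this gives
// tileUpW = (128 + 7 - 1 + 3) & ~3 = 136, tileUpH = 70, tileInW = CEIL_DIV(143, 2) = 72 and tileInH = CEIL_DIV(77, 2) = 39.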
// Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
// Sizes of logical buffers.
const int szIn = tileInH_up * tileInW;
const int szUpX = tileInH_up * tileUpW;
const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
const int szDownX = tileUpH * tileOutW;
// Sizes for shared memory arrays.
const int s_buf0_size_base =
(filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
(filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
(filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
(filterMode == MODE_FUFD) ? szIn :
-1;
const int s_buf1_size_base =
(filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
(filterMode == MODE_FUSD) ? szUpXY :
(filterMode == MODE_SUFD) ? szUpX :
(filterMode == MODE_FUFD) ? szUpXY :
-1;
// Ensure U128 alignment.
const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
// Check at compile time that we don't use too much shared memory.
static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
// Declare shared memory arrays.
scalar_t* s_buf0;
scalar_t* s_buf1;
if (sharedKB <= 48)
{
// Allocate shared memory arrays here.
__shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
s_buf0 = s_buf0_st;
s_buf1 = s_buf0 + s_buf0_size;
}
else
{
// Use the dynamically allocated shared memory array.
s_buf0 = (scalar_t*)s_buf_raw;
s_buf1 = s_buf0 + s_buf0_size;
}
// Pointers to the buffers.
scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
if (filterMode == MODE_SUSD)
{
s_tileIn = s_buf0;
s_tileUpX = s_buf1;
s_tileUpXY = s_buf0;
s_tileDownX = s_buf1;
}
else if (filterMode == MODE_FUSD)
{
s_tileIn = s_buf0;
s_tileUpXY = s_buf1;
s_tileDownX = s_buf0;
}
else if (filterMode == MODE_SUFD)
{
s_tileIn = s_buf0;
s_tileUpX = s_buf1;
s_tileUpXY = s_buf0;
}
else if (filterMode == MODE_FUFD)
{
s_tileIn = s_buf0;
s_tileUpXY = s_buf1;
}
// Allow large grids in z direction via per-launch offset.
int channelIdx = blockIdx.z + p.blockZofs;
int batchIdx = channelIdx / p.yShape.z;
channelIdx -= batchIdx * p.yShape.z;
// Offset to output feature map. In bytes.
index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
// Sign shift amount.
uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
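// Signs are packed 4 per byte (2 bits each); signXo selects this thread's 2-bit slot within a sign byte.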
// Inner tile loop.
#pragma unroll 1
for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
{
// Locate output tile.
int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
int tileOutX = tileX * tileOutW;
int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
// Locate input tile.
int tmpX = tileOutX * down - p.pad0.x;
int tmpY = tileOutY * down - p.pad0.y;
int tileInX = CEIL_DIV(tmpX, up);
int tileInY = CEIL_DIV(tmpY, up);
const int phaseInX = tileInX * up - tmpX;
const int phaseInY = tileInY * up - tmpY;
// Extra sync if input and output buffers are the same and we are not on first tile.
if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
__syncthreads();
// Load input tile & apply bias. Unrolled.
scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
int idx = threadIdx.x;
const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
#pragma unroll
for (int loop = 0; loop < loopCountIN; loop++)
{
int relInX, relInY;
fast_div_mod<tileInW>(relInX, relInY, idx);
int inX = tileInX + relInX;
int inY = tileInY + relInY;
scalar_t v = 0;
if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
if (!skip)
s_tileIn[idx] = v;
idx += threadsPerBlock;
}
if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
{
// Horizontal upsampling.
__syncthreads();
if (up == 4)
{
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
{
int relUpX0, relInY;
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
int relInX0 = relUpX0 / up;
int src0 = relInX0 + tileInW * relInY;
int dst = relInY * tileUpW + relUpX0;
vec4_t v = InternalType<T>::zero_vec4();
scalar_t a = s_tileIn[src0];
if (phaseInX == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.y += a * (scalar_t)c_fu[step * up + 3];
v.z += a * (scalar_t)c_fu[step * up + 2];
v.w += a * (scalar_t)c_fu[step * up + 1];
}
}
else if (phaseInX == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.z += a * (scalar_t)c_fu[step * up + 3];
v.w += a * (scalar_t)c_fu[step * up + 2];
}
}
else if (phaseInX == 2)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 2];
v.y += a * (scalar_t)c_fu[step * up + 1];
v.z += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.w += a * (scalar_t)c_fu[step * up + 3];
}
}
else // (phaseInX == 3)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 3];
v.y += a * (scalar_t)c_fu[step * up + 2];
v.z += a * (scalar_t)c_fu[step * up + 1];
v.w += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
}
}
s_tileUpX[dst+0] = v.x;
s_tileUpX[dst+1] = v.y;
s_tileUpX[dst+2] = v.z;
s_tileUpX[dst+3] = v.w;
}
}
else if (up == 2)
{
bool p0 = (phaseInX == 0);
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
{
int relUpX0, relInY;
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
int relInX0 = relUpX0 / up;
int src0 = relInX0 + tileInW * relInY;
int dst = relInY * tileUpW + relUpX0;
vec2_t v = InternalType<T>::zero_vec2();
scalar_t a = s_tileIn[src0];
if (p0) // (phaseInX == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.y += a * (scalar_t)c_fu[step * up + 1];
}
}
else // (phaseInX == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
}
}
s_tileUpX[dst+0] = v.x;
s_tileUpX[dst+1] = v.y;
}
}
// Vertical upsampling & nonlinearity.
__syncthreads();
int groupMask = 15 << ((threadIdx.x & 31) & ~3);
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
if (up == 4)
{
minY -= 3; // Adjust according to block height.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
{
int relUpX, relInY0;
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
int relUpY0 = relInY0 * up;
int src0 = relInY0 * tileUpW + relUpX;
int dst = relUpY0 * tileUpW + relUpX;
vec4_t v = InternalType<T>::zero_vec4();
scalar_t a = s_tileUpX[src0];
if (phaseInY == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.y += a * (scalar_t)c_fu[step * up + 3];
v.z += a * (scalar_t)c_fu[step * up + 2];
v.w += a * (scalar_t)c_fu[step * up + 1];
}
}
else if (phaseInY == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.z += a * (scalar_t)c_fu[step * up + 3];
v.w += a * (scalar_t)c_fu[step * up + 2];
}
}
else if (phaseInY == 2)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 2];
v.y += a * (scalar_t)c_fu[step * up + 1];
v.z += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.w += a * (scalar_t)c_fu[step * up + 3];
}
}
else // (phaseInY == 3)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 3];
v.y += a * (scalar_t)c_fu[step * up + 2];
v.z += a * (scalar_t)c_fu[step * up + 1];
v.w += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
}
}
int x = tileOutX * down + relUpX;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
index_t si1 = si0 + p.sShape.x;
index_t si2 = si0 + p.sShape.x * 2;
index_t si3 = si0 + p.sShape.x * 3;
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
v.z *= (scalar_t)((float)up * (float)up * p.gain);
v.w *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite)
{
if (!enableWriteSkip)
{
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
int sz = __float_as_uint(v.z) >> 31 << 16;
int sw = __float_as_uint(v.w) >> 31 << 24;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (sz) v.z *= p.slope;
if (sw) v.w *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
// s |= __shfl_xor_sync(groupMask, s, 1);
// s |= __shfl_xor_sync(groupMask, s, 2);
s |= __shfl_xor(s, 1); // HIP's __shfl_xor takes the value first and has no participation-mask argument.
s |= __shfl_xor(s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
}
}
else
{
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
int sz = __float_as_uint(v.z) >> 31 << 16;
int sw = __float_as_uint(v.w) >> 31 << 24;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (sz) v.z *= p.slope;
if (sw) v.w *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
// s |= __shfl_xor_sync(groupMask, s, 1);
// s |= __shfl_xor_sync(groupMask, s, 2);
s |= __shfl_xor(s, 1);
s |= __shfl_xor(s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
}
else
{
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
}
}
}
else if (signRead) // Read signs and apply.
{
if ((uint32_t)signXb < p.swLimit)
{
int ss = (signX & 3) << 1;
if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
}
}
else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
}
s_tileUpXY[dst + 0 * tileUpW] = v.x;
if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
}
}
else if (up == 2)
{
minY -= 1; // Adjust according to block height.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
{
int relUpX, relInY0;
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
int relUpY0 = relInY0 * up;
int src0 = relInY0 * tileUpW + relUpX;
int dst = relUpY0 * tileUpW + relUpX;
vec2_t v = InternalType<T>::zero_vec2();
scalar_t a = s_tileUpX[src0];
if (phaseInY == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.y += a * (scalar_t)c_fu[step * up + 1];
}
}
else // (phaseInY == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
}
}
int x = tileOutX * down + relUpX;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
index_t si1 = si0 + p.sShape.x;
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite)
{
if (!enableWriteSkip)
{
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
// Combine signs.
int s = sx + sy;
s <<= signXo;
// s |= __shfl_xor_sync(groupMask, s, 1);
// s |= __shfl_xor_sync(groupMask, s, 2);
s |= __shfl_xor(s, 1);
s |= __shfl_xor(s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
}
}
else
{
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
// Combine signs.
int s = sx + sy;
s <<= signXo;
// s |= __shfl_xor_sync(groupMask, s, 1);
// s |= __shfl_xor_sync(groupMask, s, 2);
s |= __shfl_xor(s, 1);
s |= __shfl_xor(s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
}
else
{
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
}
}
}
else if (signRead) // Read signs and apply.
{
if ((uint32_t)signXb < p.swLimit)
{
if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
}
}
else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
}
if (!downInline)
{
// Write into temporary buffer.
s_tileUpXY[dst] = v.x;
if (relUpY0 < tileUpH - 1)
s_tileUpXY[dst + tileUpW] = v.y;
}
else
{
// Write directly into output buffer.
if ((uint32_t)x < p.yShape.x)
{
int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
}
}
}
}
}
else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
{
// Full upsampling filter.
if (up == 2)
{
// 2 x 2-wide.
__syncthreads();
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
{
int relUpX0, relUpY0;
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
int src0 = relInX0 + tileInW * relInY0;
int tap0y = (relInY0 * up + phaseInY - relUpY0);
#define X_LOOP(TAPY, PX) \
for (int sx = 0; sx < fuSize / up; sx++) \
{ \
v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
}
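// X_LOOP accumulates one vertical tap row (TAPY) of the full 2x upsampling filter into the four upsampled
// pixels in v, walking the horizontal taps with a sliding pair of input samples (a, b); PX selects after
// which horizontal phase the pair advances to the next input sample.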
vec4_t v = InternalType<T>::zero_vec4();
if (tap0y == 0 && phaseInX == 0)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(0, 0) }
if (tap0y == 0 && phaseInX == 1)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(0, 1) }
if (tap0y == 1 && phaseInX == 0)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(1, 0) }
if (tap0y == 1 && phaseInX == 1)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(1, 1) }
#undef X_LOOP
int x = tileOutX * down + relUpX0;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
v.z *= (scalar_t)((float)up * (float)up * p.gain);
v.w *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite)
{
if (!enableWriteSkip)
{
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31;
int sy = __float_as_uint(v.y) >> 31;
int sz = __float_as_uint(v.z) >> 31;
int sw = __float_as_uint(v.w) >> 31;
if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
{
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
}
}
else
{
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
{
int sx = __float_as_uint(v.x) >> 31;
int sy = __float_as_uint(v.y) >> 31;
int sz = __float_as_uint(v.z) >> 31;
int sw = __float_as_uint(v.w) >> 31;
if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
}
else
{
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
}
}
}
else if (signRead) // Read sign and apply.
{
if ((uint32_t)signY < p.sShape.y)
{
int s = 0;
if ((uint32_t)signXb < p.swLimit) s = p.s[si];
if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
s >>= (signX & 3) << 1;
if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
}
}
else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
}
s_tileUpXY[idx + 0] = v.x;
s_tileUpXY[idx + 1] = v.y;
s_tileUpXY[idx + 2] = v.z;
s_tileUpXY[idx + 3] = v.w;
}
}
else if (up == 1)
{
__syncthreads();
uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
{
int relUpX0, relUpY0;
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
int x = tileOutX * down + relUpX0;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
v *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite)
{
if (!enableWriteSkip)
{
// Determine and write sign.
uint32_t s = 0;
uint32_t signXbit = (1u << signXo);
if (v < 0.f)
{
s = signXbit;
v *= p.slope;
}
if (fabsf(v) > p.clamp)
{
s = signXbit * 2;
v = InternalType<T>::clamp(v, p.clamp);
}
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
{
// s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
// s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
// Note: HIP's __shfl_xor takes (var, laneMask, width), so the CUDA-style mask argument is dropped here.
s += __shfl_xor(s, 1); // Coalesce.
s += __shfl_xor(s, 2); // Coalesce.
p.s[si] = s; // Write.
}
}
else
{
// Determine and write sign.
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
{
uint32_t s = 0;
uint32_t signXbit = (1u << signXo);
if (v < 0.f)
{
s = signXbit;
v *= p.slope;
}
if (fabsf(v) > p.clamp)
{
s = signXbit * 2;
v = InternalType<T>::clamp(v, p.clamp);
}
// s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
// s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
// Note: HIP's __shfl_xor takes (var, laneMask, width), so the CUDA-style mask argument is dropped here.
s += __shfl_xor(s, 1); // Coalesce.
s += __shfl_xor(s, 2); // Coalesce.
p.s[si] = s; // Write.
}
else
{
// Just compute the value.
if (v < 0.f) v *= p.slope;
v = InternalType<T>::clamp(v, p.clamp);
}
}
}
else if (signRead)
{
// Read sign and apply if within sign tensor bounds.
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
{
int s = p.s[si];
s >>= signXo;
if (s & 1) v *= p.slope;
if (s & 2) v = 0.f;
}
}
else // Forward pass with no sign write.
{
if (v < 0.f) v *= p.slope;
v = InternalType<T>::clamp(v, p.clamp);
}
if (!downInline) // Write into temporary buffer.
s_tileUpXY[idx] = v;
else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
*((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
}
}
}
// Downsampling.
if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
{
// Horizontal downsampling.
__syncthreads();
if (down == 4 && tileOutW % 4 == 0)
{
// Calculate 4 pixels at a time.
for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
{
int relOutX0, relUpY;
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
int relUpX0 = relOutX0 * down;
int src0 = relUpY * tileUpW + relUpX0;
vec4_t v = InternalType<T>::zero_vec4();
#pragma unroll
for (int step = 0; step < fdSize; step++)
{
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
}
s_tileDownX[idx+0] = v.x;
s_tileDownX[idx+1] = v.y;
s_tileDownX[idx+2] = v.z;
s_tileDownX[idx+3] = v.w;
}
}
else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
{
// Calculate 2 pixels at a time.
for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
{
int relOutX0, relUpY;
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
int relUpX0 = relOutX0 * down;
int src0 = relUpY * tileUpW + relUpX0;
vec2_t v = InternalType<T>::zero_vec2();
#pragma unroll
for (int step = 0; step < fdSize; step++)
{
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
}
s_tileDownX[idx+0] = v.x;
s_tileDownX[idx+1] = v.y;
}
}
else
{
// Calculate 1 pixel at a time.
for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
{
int relOutX0, relUpY;
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
int relUpX0 = relOutX0 * down;
int src = relUpY * tileUpW + relUpX0;
scalar_t v = 0.f;
#pragma unroll
for (int step = 0; step < fdSize; step++)
v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
s_tileDownX[idx] = v;
}
}
// Vertical downsampling & store output tile.
__syncthreads();
for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
{
int relOutX, relOutY0;
fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
int relUpY0 = relOutY0 * down;
int src0 = relUpY0 * tileOutW + relOutX;
scalar_t v = 0;
#pragma unroll
for (int step = 0; step < fdSize; step++)
v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
int outX = tileOutX + relOutX;
int outY = tileOutY + relOutY0;
if (outX < p.yShape.x & outY < p.yShape.y)
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
}
}
else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
{
// Full downsampling filter.
if (down == 2)
{
// 2-wide.
__syncthreads();
for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
{
int relOutX0, relOutY0;
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
int relUpX0 = relOutX0 * down;
int relUpY0 = relOutY0 * down;
int src0 = relUpY0 * tileUpW + relUpX0;
vec2_t v = InternalType<T>::zero_vec2();
#pragma unroll
for (int sy = 0; sy < fdSize; sy++)
#pragma unroll
for (int sx = 0; sx < fdSize; sx++)
{
v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
}
int outX = tileOutX + relOutX0;
int outY = tileOutY + relOutY0;
if ((uint32_t)outY < p.yShape.y)
{
index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
}
}
}
else if (down == 1 && !downInline)
{
// Thread per pixel.
__syncthreads();
for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
{
int relOutX0, relOutY0;
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
int outX = tileOutX + relOutX0;
int outY = tileOutY + relOutY0;
if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
}
}
}
if (!enableXrep)
break;
}
}
//------------------------------------------------------------------------
// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
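// Illustrative note (not used by the code): p.s packs 2 bits per pixel, 4 pixels per byte along x.
// On the write path a pixel stores 1 if the value was negative (slope applied) and 2 if it was clamped;
// on the read path bit 0 multiplies the value by p.slope and bit 1 zeroes it. Decoding pixel k (0..3)
// of one byte s would look roughly like:
//
//     int bits = (s >> (k * 2)) & 3;
//     if (bits & 1) v *= slope;   // value was negative
//     if (bits & 2) v  = 0.f;     // value was clamped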
template <class T, bool signWrite, bool signRead>
static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
{
typedef typename InternalType<T>::scalar_t scalar_t;
// Indexing.
int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
// Loop to accommodate oversized tensors.
for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
{
// Extract z and w (channel, minibatch index).
int32_t w = q / p.xShape.z;
int32_t z = q - w * p.xShape.z;
// Choose behavior based on sign read/write mode.
if (signWrite)
{
// Process value if in p.x.
uint32_t s = 0;
if (x < p.xShape.x && y < p.xShape.y)
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
// Gain, LReLU, clamp.
v *= p.gain;
if (v < 0.f)
{
v *= p.slope;
s = 1; // Sign.
}
if (fabsf(v) > p.clamp)
{
v = InternalType<T>::clamp(v, p.clamp);
s = 2; // Clamp.
}
*pv = (T)v; // Write value.
}
// Coalesce into threads 0 and 16 of warp.
uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
s <<= ((threadIdx.x & 15) << 1); // Shift into place.
// s |= __shfl_xor_sync(m, s, 1); // Distribute.
// s |= __shfl_xor_sync(m, s, 2);
// s |= __shfl_xor_sync(m, s, 4);
// s |= __shfl_xor_sync(m, s, 8);
// Note: HIP's __shfl_xor takes (var, laneMask, width), so the CUDA-style mask argument is dropped here.
s |= __shfl_xor(s, 1); // Distribute.
s |= __shfl_xor(s, 2);
s |= __shfl_xor(s, 4);
s |= __shfl_xor(s, 8);
// Write signs if leader and in p.s.
if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
{
uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
((uint32_t*)p.s)[is >> 4] = s;
}
}
else if (signRead)
{
// Process value if in p.x.
if (x < p.xShape.x) // y is always in.
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
v *= p.gain;
// Apply sign buffer offset.
uint32_t sx = x + p.sOfs.x;
uint32_t sy = y + p.sOfs.y;
// Read and apply signs if we land inside valid region of sign buffer.
if (sx < p.sShape.x && sy < p.sShape.y)
{
uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
unsigned char s = p.s[is];
s >>= (sx & 3) << 1; // Shift into place.
if (s & 1) // Sign?
v *= p.slope;
if (s & 2) // Clamp?
v = 0.f;
}
*pv = (T)v; // Write value.
}
}
else
{
// Forward pass with no sign write. Process value if in p.x.
if (x < p.xShape.x) // y is always in.
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
v *= p.gain;
if (v < 0.f)
v *= p.slope;
if (fabsf(v) > p.clamp)
v = InternalType<T>::clamp(v, p.clamp);
*pv = (T)v; // Write value.
}
}
}
}
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
{
return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
}
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
{
filtered_lrelu_kernel_spec s = { 0 };
// Return the first matching kernel.
#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
if (sharedKB >= SH) \
if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
{ \
static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
s.setup = (void*)setup_filters_kernel; \
s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
s.tileOut = make_int2(TW, TH); \
s.numWarps = W; \
s.xrep = XR; \
s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
return s; \
}
// Launch parameters for various kernel specializations.
// Small filters must be listed before large filters, otherwise the kernel for the larger filter will always match first.
// Kernels that use more shared memory must be listed before those that use less, for the same reason.
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
#undef CASE
return s; // No kernel found.
}
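//------------------------------------------------------------------------
// Selection example (illustrative): with p.up == 2, p.down == 2 and separable 8-tap filters
// (p.fuShape == {8, 0}, p.fdShape == {8, 0}) on a device with 48 kB of shared memory per block,
// the first matching entry above is "4t-ups2-downs2", i.e. MODE_SUSD with a 56 x 29 output tile,
// 16 warps per block and xrep == 11.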
//------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include "filtered_lrelu.h"
#include <cstdint>
//------------------------------------------------------------------------
// Helpers.
enum // Filter modes.
{
MODE_SUSD = 0, // Separable upsampling, separable downsampling.
MODE_FUSD = 1, // Full upsampling, separable downsampling.
MODE_SUFD = 2, // Separable upsampling, full downsampling.
MODE_FUFD = 3, // Full upsampling, full downsampling.
};
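// The mode is selected in choose_filtered_lrelu_kernel() from the filter shapes: a y-extent of 0
// (p.fuShape.y / p.fdShape.y) marks a separable 1-D filter, a nonzero y-extent a full 2-D filter.
// For example, a separable upsampling filter combined with a full downsampling filter selects MODE_SUFD.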
template <class T> struct InternalType;
template <> struct InternalType<double>
{
typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
__device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
};
template <> struct InternalType<float>
{
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};
template <> struct InternalType<c10::Half>
{
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};
#define MIN(A, B) ((A) < (B) ? (A) : (B))
#define MAX(A, B) ((A) > (B) ? (A) : (B))
#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
((B)==2) ? ((int)((A)+1) >> 1) : \
((B)==4) ? ((int)((A)+3) >> 2) : \
(((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
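// For example, CEIL_DIV(7, 2) == 4 and CEIL_DIV(-3, 2) == -1: the shift-based branches for B == 2
// and B == 4 give the same ceiling result as the generic branch, also for negative A (the tile
// phase computation in the main kernel relies on this for the first tile, where tmpX/tmpY can be negative).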
// Fast division and modulo by a compile-time constant N. Uses a fixed-point reciprocal when N is a
// non-power-of-two <= 256, which assumes i < N*256 (i.e. blocks up to 256 x 256); otherwise falls back to plain division.
template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
{
if ((N & (N-1)) && N <= 256)
y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
else
y = i/N;
x = i - y*N;
}
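// Worked example (illustrative): for N == 56 the reciprocal constant is (1<<24)/56 + 1 == 299594,
// so i == 14335 gives y == (14335 * 299594) >> 24 == 255 and x == 14335 - 255*56 == 55, matching
// i/N and i%N for all i < N*256. Powers of two (and N > 256) take the plain-division branch instead.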
// Type cast stride before reading it.
template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
{
return *reinterpret_cast<const T*>(&x);
}
//------------------------------------------------------------------------
// Filters, setup kernel, copying function.
#define MAX_FILTER_SIZE 32
// Combined up/down filter buffers so that transfer can be done with one copy.
__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
// Accessors to combined buffers to index up/down filters individually.
#define c_fu (c_fbuf)
#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
#define g_fu (g_fbuf)
#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
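// Full 2-D filters are addressed as c_fu[x + y * MAX_FILTER_SIZE] (row-major with a fixed row stride
// of MAX_FILTER_SIZE); separable 1-D filters occupy only the first row and are read as c_fu[x].
// The same layout applies to c_fd.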
// Set up filters into global memory buffer.
static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
{
for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
{
int x, y;
fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
if (p.fuShape.y > 0)
g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
else
g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
if (p.fdShape.y > 0)
g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
else
g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
}
}
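// Note that with p.flip == 0 the taps are written in reversed order (x -> fuShape.x - 1 - x), so the
// main kernel effectively performs a convolution; with p.flip != 0 they are copied in natural order (correlation).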
// Host function to copy filters written by setup kernel into constant buffer for main kernel.
template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
{
void* src = 0;
cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
if (err) return err;
return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
}
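// Rough sketch of the host-side flow (illustrative; the real grid/block sizes come from the kernel
// spec returned by choose_filtered_lrelu_kernel() and the launch code in the .cpp file):
//
//     void* setupArgs[] = { &p };                                     // hypothetical usage
//     cudaLaunchKernel(spec.setup, 1, 1024, setupArgs, 0, stream);    // fill g_fbuf
//     copy_filters<false, false>(stream);                             // g_fbuf -> c_fbuf
//     // ... then launch spec.exec with the tile grid and spec.numWarps * 32 threads per block.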
//------------------------------------------------------------------------
// Coordinate spaces:
// - Relative to input tensor: inX, inY, tileInX, tileInY
// - Relative to input tile: relInX, relInY, tileInW, tileInH
// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
// - Relative to output tensor: outX, outY, tileOutX, tileOutY
//
// Relationships between coordinate spaces:
// - inX = tileInX + relInX
// - inY = tileInY + relInY
// - relUpX = relInX * up + phaseInX
// - relUpY = relInY * up + phaseInY
// - relUpX = relOutX * down
// - relUpY = relOutY * down
// - outX = tileOutX + relOutX
// - outY = tileOutY + relOutY
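// Worked example (illustrative): with up = 2, down = 2 and phaseInX = 1, input column relInX = 3
// lands at upsampled column relUpX = 3*2 + 1 = 7, and only upsampled columns divisible by down map
// to output columns (relOutX = relUpX / down); the other columns exist only as intermediates in shared memory.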
extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
{
// Check that we don't try to support non-existing filter modes.
static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
static_assert(fuSize % up == 0, "upsampling filter size must be divisible by the upsampling factor");
static_assert(fdSize % down == 0, "downsampling filter size must be divisible by the downsampling factor");
static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
// Static definitions.
typedef typename InternalType<T>::scalar_t scalar_t;
typedef typename InternalType<T>::vec2_t vec2_t;
typedef typename InternalType<T>::vec4_t vec4_t;
const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
// Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
// Sizes of logical buffers.
const int szIn = tileInH_up * tileInW;
const int szUpX = tileInH_up * tileUpW;
const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
const int szDownX = tileUpH * tileOutW;
// Sizes for shared memory arrays.
const int s_buf0_size_base =
(filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
(filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
(filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
(filterMode == MODE_FUFD) ? szIn :
-1;
const int s_buf1_size_base =
(filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
(filterMode == MODE_FUSD) ? szUpXY :
(filterMode == MODE_SUFD) ? szUpX :
(filterMode == MODE_FUFD) ? szUpXY :
-1;
// Ensure U128 alignment.
const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
// Check at compile time that we don't use too much shared memory.
static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
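// Worked example (illustrative, for the 48 kB "4t-ups2-downs2" MODE_SUSD case: tileOutW = 56,
// tileOutH = 29, up = down = 2, fuSize = fdSize = 8): tileUpW = 120, tileUpH = 64, tileInW = 64,
// tileInH_up = 36, giving szIn = 2304, szUpX = 4320, szUpXY = 7680, szDownX = 3584; the two buffers
// then occupy (7680 + 4320) * sizeof(float) = 48000 bytes, just under the 48 kB budget checked above.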
// Declare shared memory arrays.
scalar_t* s_buf0;
scalar_t* s_buf1;
if (sharedKB <= 48)
{
// Allocate shared memory arrays here.
__shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
s_buf0 = s_buf0_st;
s_buf1 = s_buf0 + s_buf0_size;
}
else
{
// Use the dynamically allocated shared memory array.
s_buf0 = (scalar_t*)s_buf_raw;
s_buf1 = s_buf0 + s_buf0_size;
}
// Pointers to the buffers.
scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
if (filterMode == MODE_SUSD)
{
s_tileIn = s_buf0;
s_tileUpX = s_buf1;
s_tileUpXY = s_buf0;
s_tileDownX = s_buf1;
}
else if (filterMode == MODE_FUSD)
{
s_tileIn = s_buf0;
s_tileUpXY = s_buf1;
s_tileDownX = s_buf0;
}
else if (filterMode == MODE_SUFD)
{
s_tileIn = s_buf0;
s_tileUpX = s_buf1;
s_tileUpXY = s_buf0;
}
else if (filterMode == MODE_FUFD)
{
s_tileIn = s_buf0;
s_tileUpXY = s_buf1;
}
// Allow large grids in z direction via per-launch offset.
int channelIdx = blockIdx.z + p.blockZofs;
int batchIdx = channelIdx / p.yShape.z;
channelIdx -= batchIdx * p.yShape.z;
// Offset to output feature map. In bytes.
index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
// Sign shift amount.
uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
// Inner tile loop.
#pragma unroll 1
for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
{
// Locate output tile.
int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
int tileOutX = tileX * tileOutW;
int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
// Locate input tile.
int tmpX = tileOutX * down - p.pad0.x;
int tmpY = tileOutY * down - p.pad0.y;
int tileInX = CEIL_DIV(tmpX, up);
int tileInY = CEIL_DIV(tmpY, up);
const int phaseInX = tileInX * up - tmpX;
const int phaseInY = tileInY * up - tmpY;
// Extra sync if input and output buffers are the same and we are not on first tile.
if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
__syncthreads();
// Load input tile & apply bias. Unrolled.
scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
int idx = threadIdx.x;
const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
#pragma unroll
for (int loop = 0; loop < loopCountIN; loop++)
{
int relInX, relInY;
fast_div_mod<tileInW>(relInX, relInY, idx);
int inX = tileInX + relInX;
int inY = tileInY + relInY;
scalar_t v = 0;
if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
if (!skip)
s_tileIn[idx] = v;
idx += threadsPerBlock;
}
if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
{
// Horizontal upsampling.
__syncthreads();
if (up == 4)
{
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
{
int relUpX0, relInY;
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
int relInX0 = relUpX0 / up;
int src0 = relInX0 + tileInW * relInY;
int dst = relInY * tileUpW + relUpX0;
vec4_t v = InternalType<T>::zero_vec4();
scalar_t a = s_tileIn[src0];
if (phaseInX == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.y += a * (scalar_t)c_fu[step * up + 3];
v.z += a * (scalar_t)c_fu[step * up + 2];
v.w += a * (scalar_t)c_fu[step * up + 1];
}
}
else if (phaseInX == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.z += a * (scalar_t)c_fu[step * up + 3];
v.w += a * (scalar_t)c_fu[step * up + 2];
}
}
else if (phaseInX == 2)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 2];
v.y += a * (scalar_t)c_fu[step * up + 1];
v.z += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.w += a * (scalar_t)c_fu[step * up + 3];
}
}
else // (phaseInX == 3)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 3];
v.y += a * (scalar_t)c_fu[step * up + 2];
v.z += a * (scalar_t)c_fu[step * up + 1];
v.w += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
}
}
s_tileUpX[dst+0] = v.x;
s_tileUpX[dst+1] = v.y;
s_tileUpX[dst+2] = v.z;
s_tileUpX[dst+3] = v.w;
}
}
else if (up == 2)
{
bool p0 = (phaseInX == 0);
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
{
int relUpX0, relInY;
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
int relInX0 = relUpX0 / up;
int src0 = relInX0 + tileInW * relInY;
int dst = relInY * tileUpW + relUpX0;
vec2_t v = InternalType<T>::zero_vec2();
scalar_t a = s_tileIn[src0];
if (p0) // (phaseInX == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
v.y += a * (scalar_t)c_fu[step * up + 1];
}
}
else // (phaseInX == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileIn[src0 + step + 1];
}
}
s_tileUpX[dst+0] = v.x;
s_tileUpX[dst+1] = v.y;
}
}
// Vertical upsampling & nonlinearity.
__syncthreads();
int groupMask = 15 << ((threadIdx.x & 31) & ~3);
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
if (up == 4)
{
minY -= 3; // Adjust according to block height.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
{
int relUpX, relInY0;
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
int relUpY0 = relInY0 * up;
int src0 = relInY0 * tileUpW + relUpX;
int dst = relUpY0 * tileUpW + relUpX;
vec4_t v = InternalType<T>::zero_vec4();
scalar_t a = s_tileUpX[src0];
if (phaseInY == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.y += a * (scalar_t)c_fu[step * up + 3];
v.z += a * (scalar_t)c_fu[step * up + 2];
v.w += a * (scalar_t)c_fu[step * up + 1];
}
}
else if (phaseInY == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.z += a * (scalar_t)c_fu[step * up + 3];
v.w += a * (scalar_t)c_fu[step * up + 2];
}
}
else if (phaseInY == 2)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 2];
v.y += a * (scalar_t)c_fu[step * up + 1];
v.z += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.w += a * (scalar_t)c_fu[step * up + 3];
}
}
else // (phaseInY == 3)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 3];
v.y += a * (scalar_t)c_fu[step * up + 2];
v.z += a * (scalar_t)c_fu[step * up + 1];
v.w += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
}
}
int x = tileOutX * down + relUpX;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
index_t si1 = si0 + p.sShape.x;
index_t si2 = si0 + p.sShape.x * 2;
index_t si3 = si0 + p.sShape.x * 3;
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
v.z *= (scalar_t)((float)up * (float)up * p.gain);
v.w *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite)
{
if (!enableWriteSkip)
{
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
int sz = __float_as_uint(v.z) >> 31 << 16;
int sw = __float_as_uint(v.w) >> 31 << 24;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (sz) v.z *= p.slope;
if (sw) v.w *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
}
}
else
{
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
int sz = __float_as_uint(v.z) >> 31 << 16;
int sw = __float_as_uint(v.w) >> 31 << 24;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (sz) v.z *= p.slope;
if (sw) v.w *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
// Combine signs.
uint32_t s = sx + sy + sw + sz;
s <<= (signX & 3) << 1;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
}
else
{
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
}
}
}
else if (signRead) // Read signs and apply.
{
if ((uint32_t)signXb < p.swLimit)
{
int ss = (signX & 3) << 1;
if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
}
}
else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
}
s_tileUpXY[dst + 0 * tileUpW] = v.x;
if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
}
}
else if (up == 2)
{
minY -= 1; // Adjust according to block height.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
{
int relUpX, relInY0;
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
int relUpY0 = relInY0 * up;
int src0 = relInY0 * tileUpW + relUpX;
int dst = relUpY0 * tileUpW + relUpX;
vec2_t v = InternalType<T>::zero_vec2();
scalar_t a = s_tileUpX[src0];
if (phaseInY == 0)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
v.y += a * (scalar_t)c_fu[step * up + 1];
}
}
else // (phaseInY == 1)
{
#pragma unroll
for (int step = 0; step < fuSize / up; step++)
{
v.x += a * (scalar_t)c_fu[step * up + 1];
v.y += a * (scalar_t)c_fu[step * up + 0];
a = s_tileUpX[src0 + (step + 1) * tileUpW];
}
}
int x = tileOutX * down + relUpX;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
index_t si1 = si0 + p.sShape.x;
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite)
{
if (!enableWriteSkip)
{
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
// Combine signs.
int s = sx + sy;
s <<= signXo;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
}
}
else
{
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit && signY >= minY)
{
int sx = __float_as_uint(v.x) >> 31 << 0;
int sy = __float_as_uint(v.y) >> 31 << 8;
if (sx) v.x *= p.slope;
if (sy) v.y *= p.slope;
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
// Combine signs.
int s = sx + sy;
s <<= signXo;
s |= __shfl_xor_sync(groupMask, s, 1);
s |= __shfl_xor_sync(groupMask, s, 2);
// Write signs.
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
}
else
{
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
}
}
}
else if (signRead) // Read signs and apply.
{
if ((uint32_t)signXb < p.swLimit)
{
if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
}
}
else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
}
if (!downInline)
{
// Write into temporary buffer.
s_tileUpXY[dst] = v.x;
if (relUpY0 < tileUpH - 1)
s_tileUpXY[dst + tileUpW] = v.y;
}
else
{
// Write directly into output buffer.
if ((uint32_t)x < p.yShape.x)
{
int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
}
}
}
}
}
else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
{
// Full upsampling filter.
if (up == 2)
{
// 2 x 2-wide.
__syncthreads();
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
{
int relUpX0, relUpY0;
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
int src0 = relInX0 + tileInW * relInY0;
int tap0y = (relInY0 * up + phaseInY - relUpY0);
#define X_LOOP(TAPY, PX) \
for (int sx = 0; sx < fuSize / up; sx++) \
{ \
v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
}
vec4_t v = InternalType<T>::zero_vec4();
if (tap0y == 0 && phaseInX == 0)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(0, 0) }
if (tap0y == 0 && phaseInX == 1)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(0, 1) }
if (tap0y == 1 && phaseInX == 0)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(1, 0) }
if (tap0y == 1 && phaseInX == 1)
#pragma unroll
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
#pragma unroll
X_LOOP(1, 1) }
#undef X_LOOP
int x = tileOutX * down + relUpX0;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
v.x *= (scalar_t)((float)up * (float)up * p.gain);
v.y *= (scalar_t)((float)up * (float)up * p.gain);
v.z *= (scalar_t)((float)up * (float)up * p.gain);
v.w *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite)
{
if (!enableWriteSkip)
{
// Determine and write signs.
int sx = __float_as_uint(v.x) >> 31;
int sy = __float_as_uint(v.y) >> 31;
int sz = __float_as_uint(v.z) >> 31;
int sw = __float_as_uint(v.w) >> 31;
if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
{
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
}
}
else
{
// Determine and write signs.
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
{
int sx = __float_as_uint(v.x) >> 31;
int sy = __float_as_uint(v.y) >> 31;
int sz = __float_as_uint(v.z) >> 31;
int sw = __float_as_uint(v.w) >> 31;
if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
}
else
{
// Just compute the values.
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
}
}
}
else if (signRead) // Read sign and apply.
{
if ((uint32_t)signY < p.sShape.y)
{
int s = 0;
if ((uint32_t)signXb < p.swLimit) s = p.s[si];
if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
s >>= (signX & 3) << 1;
if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
}
}
else // Forward pass with no sign write.
{
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
}
s_tileUpXY[idx + 0] = v.x;
s_tileUpXY[idx + 1] = v.y;
s_tileUpXY[idx + 2] = v.z;
s_tileUpXY[idx + 3] = v.w;
}
}
else if (up == 1)
{
__syncthreads();
uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
{
int relUpX0, relUpY0;
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
int x = tileOutX * down + relUpX0;
int y = tileOutY * down + relUpY0;
int signX = x + p.sOfs.x;
int signY = y + p.sOfs.y;
int signZ = blockIdx.z + p.blockZofs;
int signXb = signX >> 2;
index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
v *= (scalar_t)((float)up * (float)up * p.gain);
if (signWrite)
{
if (!enableWriteSkip)
{
// Determine and write sign.
uint32_t s = 0;
uint32_t signXbit = (1u << signXo);
if (v < 0.f)
{
s = signXbit;
v *= p.slope;
}
if (fabsf(v) > p.clamp)
{
s = signXbit * 2;
v = InternalType<T>::clamp(v, p.clamp);
}
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
{
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
p.s[si] = s; // Write.
}
}
else
{
// Determine and write sign.
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
{
uint32_t s = 0;
uint32_t signXbit = (1u << signXo);
if (v < 0.f)
{
s = signXbit;
v *= p.slope;
}
if (fabsf(v) > p.clamp)
{
s = signXbit * 2;
v = InternalType<T>::clamp(v, p.clamp);
}
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
p.s[si] = s; // Write.
}
else
{
// Just compute the value.
if (v < 0.f) v *= p.slope;
v = InternalType<T>::clamp(v, p.clamp);
}
}
}
else if (signRead)
{
// Read sign and apply if within sign tensor bounds.
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
{
int s = p.s[si];
s >>= signXo;
if (s & 1) v *= p.slope;
if (s & 2) v = 0.f;
}
}
else // Forward pass with no sign write.
{
if (v < 0.f) v *= p.slope;
v = InternalType<T>::clamp(v, p.clamp);
}
if (!downInline) // Write into temporary buffer.
s_tileUpXY[idx] = v;
else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
*((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
}
}
}
// Downsampling.
if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
{
// Horizontal downsampling.
__syncthreads();
if (down == 4 && tileOutW % 4 == 0)
{
// Calculate 4 pixels at a time.
for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
{
int relOutX0, relUpY;
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
int relUpX0 = relOutX0 * down;
int src0 = relUpY * tileUpW + relUpX0;
vec4_t v = InternalType<T>::zero_vec4();
#pragma unroll
for (int step = 0; step < fdSize; step++)
{
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
}
s_tileDownX[idx+0] = v.x;
s_tileDownX[idx+1] = v.y;
s_tileDownX[idx+2] = v.z;
s_tileDownX[idx+3] = v.w;
}
}
else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
{
// Calculate 2 pixels at a time.
for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
{
int relOutX0, relUpY;
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
int relUpX0 = relOutX0 * down;
int src0 = relUpY * tileUpW + relUpX0;
vec2_t v = InternalType<T>::zero_vec2();
#pragma unroll
for (int step = 0; step < fdSize; step++)
{
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
}
s_tileDownX[idx+0] = v.x;
s_tileDownX[idx+1] = v.y;
}
}
else
{
// Calculate 1 pixel at a time.
for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
{
int relOutX0, relUpY;
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
int relUpX0 = relOutX0 * down;
int src = relUpY * tileUpW + relUpX0;
scalar_t v = 0.f;
#pragma unroll
for (int step = 0; step < fdSize; step++)
v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
s_tileDownX[idx] = v;
}
}
// Vertical downsampling & store output tile.
__syncthreads();
for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
{
int relOutX, relOutY0;
fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
int relUpY0 = relOutY0 * down;
int src0 = relUpY0 * tileOutW + relOutX;
scalar_t v = 0;
#pragma unroll
for (int step = 0; step < fdSize; step++)
v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
int outX = tileOutX + relOutX;
int outY = tileOutY + relOutY0;
if (outX < p.yShape.x & outY < p.yShape.y)
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
}
}
else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
{
// Full downsampling filter.
if (down == 2)
{
// 2-wide.
__syncthreads();
for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
{
int relOutX0, relOutY0;
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
int relUpX0 = relOutX0 * down;
int relUpY0 = relOutY0 * down;
int src0 = relUpY0 * tileUpW + relUpX0;
vec2_t v = InternalType<T>::zero_vec2();
#pragma unroll
for (int sy = 0; sy < fdSize; sy++)
#pragma unroll
for (int sx = 0; sx < fdSize; sx++)
{
v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
}
int outX = tileOutX + relOutX0;
int outY = tileOutY + relOutY0;
if ((uint32_t)outY < p.yShape.y)
{
index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
}
}
}
else if (down == 1 && !downInline)
{
// Thread per pixel.
__syncthreads();
for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
{
int relOutX0, relOutY0;
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
int outX = tileOutX + relOutX0;
int outY = tileOutY + relOutY0;
if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
}
}
}
if (!enableXrep)
break;
}
}
//------------------------------------------------------------------------
// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
template <class T, bool signWrite, bool signRead>
static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
{
typedef typename InternalType<T>::scalar_t scalar_t;
// Indexing.
int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
// Loop to accommodate oversized tensors.
for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
{
// Extract z and w (channel, minibatch index).
int32_t w = q / p.xShape.z;
int32_t z = q - w * p.xShape.z;
// Choose behavior based on sign read/write mode.
if (signWrite)
{
// Process value if in p.x.
uint32_t s = 0;
if (x < p.xShape.x && y < p.xShape.y)
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
// Gain, LReLU, clamp.
v *= p.gain;
if (v < 0.f)
{
v *= p.slope;
s = 1; // Sign.
}
if (fabsf(v) > p.clamp)
{
v = InternalType<T>::clamp(v, p.clamp);
s = 2; // Clamp.
}
*pv = (T)v; // Write value.
}
// Coalesce into threads 0 and 16 of warp.
uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
s <<= ((threadIdx.x & 15) << 1); // Shift into place.
s |= __shfl_xor_sync(m, s, 1); // Distribute.
s |= __shfl_xor_sync(m, s, 2);
s |= __shfl_xor_sync(m, s, 4);
s |= __shfl_xor_sync(m, s, 8);
// Write signs if leader and in p.s.
if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
{
uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
((uint32_t*)p.s)[is >> 4] = s;
}
}
else if (signRead)
{
// Process value if in p.x.
if (x < p.xShape.x) // y is always in.
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
v *= p.gain;
// Apply sign buffer offset.
uint32_t sx = x + p.sOfs.x;
uint32_t sy = y + p.sOfs.y;
// Read and apply signs if we land inside valid region of sign buffer.
if (sx < p.sShape.x && sy < p.sShape.y)
{
uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
unsigned char s = p.s[is];
s >>= (sx & 3) << 1; // Shift into place.
if (s & 1) // Sign?
v *= p.slope;
if (s & 2) // Clamp?
v = 0.f;
}
*pv = (T)v; // Write value.
}
}
else
{
// Forward pass with no sign write. Process value if in p.x.
if (x < p.xShape.x) // y is always in.
{
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
T* pv = ((T*)p.x) + ix;
scalar_t v = (scalar_t)(*pv);
v *= p.gain;
if (v < 0.f)
v *= p.slope;
if (fabsf(v) > p.clamp)
v = InternalType<T>::clamp(v, p.clamp);
*pv = (T)v; // Write value.
}
}
}
}
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
{
return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
}
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
{
filtered_lrelu_kernel_spec s = { 0 };
// Return the first matching kernel.
#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
if (sharedKB >= SH) \
if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
{ \
static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
s.setup = (void*)setup_filters_kernel; \
s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
s.tileOut = make_int2(TW, TH); \
s.numWarps = W; \
s.xrep = XR; \
s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
return s; \
}
// Launch parameters for various kernel specializations.
// Small filters must be listed before large filters, otherwise the kernel for the larger filter will always match first.
// Kernels that use more shared memory must be listed before those that use less, for the same reason.
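// Note on matching: a separable filter is indicated by fuShape.y == 0 (resp. fdShape.y == 0),
// and the first CASE whose filter mode, up/down factors, and maximum filter sizes all match
// the requested parameters is returned.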
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
#undef CASE
return s; // No kernel found.
}
//------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <hip/hip_runtime.h>
//------------------------------------------------------------------------
// CUDA kernel parameters.
struct filtered_lrelu_kernel_params
{
// These parameters decide which kernel to use.
int up; // upsampling ratio (1, 2, 4)
int down; // downsampling ratio (1, 2, 4)
int2 fuShape; // [size, 1] | [size, size]
int2 fdShape; // [size, 1] | [size, size]
int _dummy; // Alignment.
// Rest of the parameters.
const void* x; // Input tensor.
void* y; // Output tensor.
const void* b; // Bias tensor.
unsigned char* s; // Sign tensor in/out. NULL if unused.
const float* fu; // Upsampling filter.
const float* fd; // Downsampling filter.
int2 pad0; // Left/top padding.
float gain; // Additional gain factor.
float slope; // Leaky ReLU slope on negative side.
float clamp; // Clamp after nonlinearity.
int flip; // Filter kernel flip for gradient computation.
int tilesXdim; // Original number of horizontal output tiles.
int tilesXrep; // Number of horizontal tiles per CTA.
int blockZofs; // Block z offset to support large minibatch, channel dimensions.
int4 xShape; // [width, height, channel, batch]
int4 yShape; // [width, height, channel, batch]
int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
int swLimit; // Active width of sign tensor in bytes.
longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
longlong4 yStride; //
int64_t bStride; //
longlong3 fuStride; //
longlong3 fdStride; //
};
struct filtered_lrelu_act_kernel_params
{
void* x; // Input/output, modified in-place.
unsigned char* s; // Sign tensor in/out. NULL if unused.
float gain; // Additional gain factor.
float slope; // Leaky ReLU slope on negative side.
float clamp; // Clamp after nonlinearity.
int4 xShape; // [width, height, channel, batch]
longlong4 xStride; // Input/output tensor strides, same order as in shape.
int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
};
//------------------------------------------------------------------------
// CUDA kernel specialization.
struct filtered_lrelu_kernel_spec
{
void* setup; // Function for filter kernel setup.
void* exec; // Function for main operation.
int2 tileOut; // Width/height of launch tile.
int numWarps; // Number of warps per thread block, determines launch block size.
int xrep; // For processing multiple horizontal tiles per thread block.
int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
};
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
template <bool signWrite, bool signRead> hipError_t copy_filters(hipStream_t stream);
//------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <cuda_runtime.h>
//------------------------------------------------------------------------
// CUDA kernel parameters.
struct filtered_lrelu_kernel_params
{
// These parameters decide which kernel to use.
int up; // upsampling ratio (1, 2, 4)
int down; // downsampling ratio (1, 2, 4)
int2 fuShape; // [size, 1] | [size, size]
int2 fdShape; // [size, 1] | [size, size]
int _dummy; // Alignment.
// Rest of the parameters.
const void* x; // Input tensor.
void* y; // Output tensor.
const void* b; // Bias tensor.
unsigned char* s; // Sign tensor in/out. NULL if unused.
const float* fu; // Upsampling filter.
const float* fd; // Downsampling filter.
int2 pad0; // Left/top padding.
float gain; // Additional gain factor.
float slope; // Leaky ReLU slope on negative side.
float clamp; // Clamp after nonlinearity.
int flip; // Filter kernel flip for gradient computation.
int tilesXdim; // Original number of horizontal output tiles.
int tilesXrep; // Number of horizontal tiles per CTA.
int blockZofs; // Block z offset to support large minibatch, channel dimensions.
int4 xShape; // [width, height, channel, batch]
int4 yShape; // [width, height, channel, batch]
int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
int swLimit; // Active width of sign tensor in bytes.
longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
longlong4 yStride; //
int64_t bStride; //
longlong3 fuStride; //
longlong3 fdStride; //
};
struct filtered_lrelu_act_kernel_params
{
void* x; // Input/output, modified in-place.
unsigned char* s; // Sign tensor in/out. NULL if unused.
float gain; // Additional gain factor.
float slope; // Leaky ReLU slope on negative side.
float clamp; // Clamp after nonlinearity.
int4 xShape; // [width, height, channel, batch]
longlong4 xStride; // Input/output tensor strides, same order as in shape.
int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
};
//------------------------------------------------------------------------
// CUDA kernel specialization.
struct filtered_lrelu_kernel_spec
{
void* setup; // Function for filter kernel setup.
void* exec; // Function for main operation.
int2 tileOut; // Width/height of launch tile.
int numWarps; // Number of warps per thread block, determines launch block size.
int xrep; // For processing multiple horizontal tiles per thread block.
int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
};
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
//------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import numpy as np
import torch
import warnings
from .. import custom_ops
from .. import misc
from . import upfirdn2d
from . import bias_act
#----------------------------------------------------------------------------
_plugin = None
def _init():
global _plugin
if _plugin is None:
_plugin = custom_ops.get_plugin(
module_name='filtered_lrelu_plugin',
sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
source_dir=os.path.dirname(__file__),
extra_cuda_cflags=['--gpu-max-threads-per-block=1024', '--std=c++17'],
)
return True
def _get_filter_size(f):
if f is None:
return 1, 1
assert isinstance(f, torch.Tensor)
assert 1 <= f.ndim <= 2
return f.shape[-1], f.shape[0] # width, height
def _parse_padding(padding):
if isinstance(padding, int):
padding = [padding, padding]
assert isinstance(padding, (list, tuple))
assert all(isinstance(x, (int, np.integer)) for x in padding)
padding = [int(x) for x in padding]
if len(padding) == 2:
px, py = padding
padding = [px, px, py, py]
px0, px1, py0, py1 = padding
return px0, px1, py0, py1
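# For example, _parse_padding(2) returns (2, 2, 2, 2) and _parse_padding([1, 2]) returns
# (1, 1, 2, 2), i.e. a 2-element [x, y] spec expands to [x_before, x_after, y_before, y_after].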
#----------------------------------------------------------------------------
def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
r"""Filtered leaky ReLU for a batch of 2D images.
Performs the following sequence of operations for each channel:
1. Add channel-specific bias if provided (`b`).
2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
3. Pad the image with the specified number of zeros on each side (`padding`).
Negative padding corresponds to cropping the image.
4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
so that the footprint of all output pixels lies within the input image.
5. Multiply each value by the provided gain factor (`gain`).
6. Apply leaky ReLU activation function to each value.
7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
it so that the footprint of all output pixels lies within the input image.
9. Downsample the image by keeping every Nth pixel (`down`).
The fused op is considerably more efficient than performing the same calculation
using standard PyTorch ops. It supports gradients of arbitrary order.
Args:
x: Float32/float16/float64 input tensor of the shape
`[batch_size, num_channels, in_height, in_width]`.
fu: Float32 upsampling FIR filter of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or
`None` (identity).
fd: Float32 downsampling FIR filter of the shape
`[filter_height, filter_width]` (non-separable),
`[filter_taps]` (separable), or
`None` (identity).
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
as `x`. The length of the vector must match the channel dimension of `x`.
up: Integer upsampling factor (default: 1).
down: Integer downsampling factor (default: 1).
padding: Padding with respect to the upsampled image. Can be a single number
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
(default: 0).
gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
slope: Slope on the negative side of leaky ReLU (default: 0.2).
clamp: Maximum magnitude for leaky ReLU output (default: None).
flip_filter: False = convolution, True = correlation (default: False).
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
Returns:
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
"""
assert isinstance(x, torch.Tensor)
assert impl in ['ref', 'cuda']
if impl == 'cuda' and x.device.type == 'cuda' and _init():
return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
#----------------------------------------------------------------------------
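# Minimal usage sketch of the reference code path (illustrative only; `_filtered_lrelu_example`
# is a hypothetical helper and is never called at import time). With 4-tap separable filters,
# up = down = 2, and padding 3 on every side, the output size follows the formula used by the
# reference implementation below: (16*2 + 3+3 - (4-1) - (4-1) + (2-1)) // 2 = 16.
def _filtered_lrelu_example():
    x = torch.randn(1, 3, 16, 16)
    f = torch.tensor([1.0, 3.0, 3.0, 1.0])
    f = f / f.sum()                          # separable low-pass FIR filter
    y = filtered_lrelu(x, fu=f, fd=f, up=2, down=2, padding=3, impl='ref')
    return y.shape                           # expected: [1, 3, 16, 16]
#----------------------------------------------------------------------------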
@misc.profiled_function
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
"""Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
existing `upfirdn2d()` and `bias_act()` ops.
"""
assert isinstance(x, torch.Tensor) and x.ndim == 4
fu_w, fu_h = _get_filter_size(fu)
fd_w, fd_h = _get_filter_size(fd)
if b is not None:
assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
misc.assert_shape(b, [x.shape[1]])
assert isinstance(up, int) and up >= 1
assert isinstance(down, int) and down >= 1
px0, px1, py0, py1 = _parse_padding(padding)
assert gain == float(gain) and gain > 0
assert slope == float(slope) and slope >= 0
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
# Calculate output size.
batch_size, channels, in_h, in_w = x.shape
in_dtype = x.dtype
out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
# Compute using existing ops.
x = bias_act.bias_act(x=x, b=b) # Apply bias.
x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
# Check output shape & dtype.
misc.assert_shape(x, [batch_size, channels, out_h, out_w])
assert x.dtype == in_dtype
return x
#----------------------------------------------------------------------------
_filtered_lrelu_cuda_cache = dict()
def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
"""Fast CUDA implementation of `filtered_lrelu()` using custom ops.
"""
assert isinstance(up, int) and up >= 1
assert isinstance(down, int) and down >= 1
px0, px1, py0, py1 = _parse_padding(padding)
assert gain == float(gain) and gain > 0
gain = float(gain)
assert slope == float(slope) and slope >= 0
slope = float(slope)
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
clamp = float(clamp if clamp is not None else 'inf')
# Lookup from cache.
key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
if key in _filtered_lrelu_cuda_cache:
return _filtered_lrelu_cuda_cache[key]
# Forward op.
class FilteredLReluCuda(torch.autograd.Function):
@staticmethod
def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
assert isinstance(x, torch.Tensor) and x.ndim == 4
# Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
if fu is None:
fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
if fd is None:
fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
assert 1 <= fu.ndim <= 2
assert 1 <= fd.ndim <= 2
# Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
fu = fu.square()[None]
if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
fd = fd.square()[None]
# Missing sign input tensor.
if si is None:
si = torch.empty([0])
# Missing bias tensor.
if b is None:
b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
# Construct internal sign tensor only if gradients are needed.
write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
# Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
if any(a < b for a, b in zip(strides[:-1], strides[1:])):
warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
# Call C++/Cuda plugin if datatype is supported.
if x.dtype in [torch.float16, torch.float32]:
if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
else:
return_code = -1
# No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
# only the bit-packed sign tensor is retained for gradient computation.
if return_code < 0:
warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
# Prepare for gradient computation.
ctx.save_for_backward(fu, fd, (si if si.numel() else so))
ctx.x_shape = x.shape
ctx.y_shape = y.shape
ctx.s_ofs = sx, sy
return y
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
fu, fd, si = ctx.saved_tensors
_, _, xh, xw = ctx.x_shape
_, _, yh, yw = ctx.y_shape
sx, sy = ctx.s_ofs
dx = None # 0
dfu = None; assert not ctx.needs_input_grad[1]
dfd = None; assert not ctx.needs_input_grad[2]
db = None # 3
dsi = None; assert not ctx.needs_input_grad[4]
dsx = None; assert not ctx.needs_input_grad[5]
dsy = None; assert not ctx.needs_input_grad[6]
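# The gradient of the fused op is itself a filtered_lrelu: the up/down factors are swapped,
# the two filters are exchanged and flipped, the padding is adjusted so the receptive fields
# line up, and the saved sign tensor replays the slope/clamp decisions made in the forward pass.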
if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
pp = [
(fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
xw * up - yw * down + px0 - (up - 1),
(fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
xh * up - yh * down + py0 - (up - 1),
]
gg = gain * (up ** 2) / (down ** 2)
ff = (not flip_filter)
sx = sx - (fu.shape[-1] - 1) + px0
sy = sy - (fu.shape[0] - 1) + py0
dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
if ctx.needs_input_grad[3]:
db = dx.sum([0, 2, 3])
return dx, dfu, dfd, db, dsi, dsx, dsy
# Add to cache.
_filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
return FilteredLReluCuda
#----------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include "filtered_lrelu.cu"
// Template/kernel specializations for no signs mode (no gradients required).
// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
// Copy filters to constant memory.
template hipError_t copy_filters<false, false>(hipStream_t stream);
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include "filtered_lrelu.cu"
// Template/kernel specializations for no signs mode (no gradients required).
// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
// Copy filters to constant memory.
template cudaError_t copy_filters<false, false>(cudaStream_t stream);
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include "filtered_lrelu.cu"
// Template/kernel specializations for sign read mode.
// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
// Copy filters to constant memory.
template hipError_t copy_filters<false, true>(hipStream_t stream);
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include "filtered_lrelu.cu"
// Template/kernel specializations for sign read mode.
// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
// Copy filters to constant memory.
template cudaError_t copy_filters<false, true>(cudaStream_t stream);
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include "filtered_lrelu.cu"
// Template/kernel specializations for sign write mode.
// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
// Copy filters to constant memory.
template hipError_t copy_filters<true, false>(hipStream_t stream);
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include "filtered_lrelu.cu"
// Template/kernel specializations for sign write mode.
// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
// Copy filters to constant memory.
template cudaError_t copy_filters<true, false>(cudaStream_t stream);
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
import torch
#----------------------------------------------------------------------------
def fma(a, b, c): # => a * b + c
return _FusedMultiplyAdd.apply(a, b, c)
#----------------------------------------------------------------------------
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
@staticmethod
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
out = torch.addcmul(c, a, b)
ctx.save_for_backward(a, b)
ctx.c_shape = c.shape
return out
@staticmethod
def backward(ctx, dout): # pylint: disable=arguments-differ
a, b = ctx.saved_tensors
c_shape = ctx.c_shape
da = None
db = None
dc = None
if ctx.needs_input_grad[0]:
da = _unbroadcast(dout * b, a.shape)
if ctx.needs_input_grad[1]:
db = _unbroadcast(dout * a, b.shape)
if ctx.needs_input_grad[2]:
dc = _unbroadcast(dout, c_shape)
return da, db, dc
#----------------------------------------------------------------------------
def _unbroadcast(x, shape):
extra_dims = x.ndim - len(shape)
assert extra_dims >= 0
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
if len(dim):
x = x.sum(dim=dim, keepdim=True)
if extra_dims:
x = x.reshape(-1, *x.shape[extra_dims+1:])
assert x.shape == shape
return x
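# Minimal sketch of the broadcasting behaviour (illustrative only; `_fma_demo` is a hypothetical
# helper). Gradients are reduced back to the original operand shapes by `_unbroadcast`, mirroring
# what autograd does for broadcast inputs of a * b + c.
def _fma_demo():
    a = torch.randn(4, 1, requires_grad=True)
    b = torch.randn(1, 3, requires_grad=True)
    c = torch.randn(4, 3, requires_grad=True)
    out = fma(a, b, c)               # same value as a * b + c
    out.sum().backward()
    return a.grad.shape, b.grad.shape, c.grad.shape   # torch.Size([4, 1]), ([1, 3]), ([4, 3])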
#----------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Custom replacement for `torch.nn.functional.grid_sample` that
supports arbitrarily high order gradients between the input and output.
Only works on 2D images and assumes
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
import torch
from pkg_resources import parse_version
# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access
#----------------------------------------------------------------------------
enabled = False # Enable the custom op by setting this to true.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12
#----------------------------------------------------------------------------
def grid_sample(input, grid):
if _should_use_custom_op():
return _GridSample2dForward.apply(input, grid)
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
#----------------------------------------------------------------------------
def _should_use_custom_op():
return enabled
#----------------------------------------------------------------------------
class _GridSample2dForward(torch.autograd.Function):
@staticmethod
def forward(ctx, input, grid):
assert input.ndim == 4
assert grid.ndim == 4
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
ctx.save_for_backward(input, grid)
return output
@staticmethod
def backward(ctx, grad_output):
input, grid = ctx.saved_tensors
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
return grad_input, grad_grid
#----------------------------------------------------------------------------
class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
if _use_pytorch_1_12_api:
op = op[0]
if _use_pytorch_1_11_api:
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
else:
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
ctx.save_for_backward(grid)
return grad_input, grad_grid
@staticmethod
def backward(ctx, grad2_grad_input, grad2_grad_grid):
_ = grad2_grad_grid # unused
grid, = ctx.saved_tensors
grad2_grad_output = None
grad2_input = None
grad2_grid = None
if ctx.needs_input_grad[0]:
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
assert not ctx.needs_input_grad[2]
return grad2_grad_output, grad2_input, grad2_grid
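# Minimal usage sketch (illustrative only; `_grid_sample_example` is a hypothetical helper).
# With `enabled = False` this dispatches to the standard torch.nn.functional.grid_sample;
# enabling the custom op routes the backward pass through _GridSample2dBackward, which per
# the module docstring supports arbitrarily high order gradients between input and output.
def _grid_sample_example():
    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    theta = torch.eye(2, 3).unsqueeze(0)   # identity affine transform, shape [1, 2, 3]
    g = torch.nn.functional.affine_grid(theta, size=[1, 3, 8, 8], align_corners=False)
    y = grid_sample(x, g)
    dx, = torch.autograd.grad(y.sum(), x, create_graph=True)
    return dx.shape                        # torch.Size([1, 3, 8, 8])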
#----------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <torch/extension.h>
#include <ATen/cuda/HIPContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"
//------------------------------------------------------------------------
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
// Validate arguments.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
TORCH_CHECK(x.numel() > 0, "x has zero size");
TORCH_CHECK(f.numel() > 0, "f has zero size");
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
// Create output tensor.
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
// Initialize CUDA kernel parameters.
upfirdn2d_kernel_params p;
p.x = x.data_ptr();
p.f = f.data_ptr<float>();
p.y = y.data_ptr();
p.up = make_int2(upx, upy);
p.down = make_int2(downx, downy);
p.pad0 = make_int2(padx0, pady0);
p.flip = (flip) ? 1 : 0;
p.gain = gain;
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
// Choose CUDA kernel.
upfirdn2d_kernel_spec spec;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
{
spec = choose_upfirdn2d_kernel<scalar_t>(p);
});
// Set looping options.
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
p.loopMinor = spec.loopMinor;
p.loopX = spec.loopX;
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
// Compute grid size.
dim3 blockSize, gridSize;
if (spec.tileOutW < 0) // large
{
blockSize = dim3(4, 32, 1);
gridSize = dim3(
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
p.launchMajor);
}
else // small
{
blockSize = dim3(256, 1, 1);
gridSize = dim3(
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
p.launchMajor);
}
// Launch CUDA kernel.
void* args[] = {&p};
AT_CUDA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
return y;
}
//------------------------------------------------------------------------
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("upfirdn2d", &upfirdn2d);
}
//------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"
//------------------------------------------------------------------------
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
// Validate arguments.
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
TORCH_CHECK(x.numel() > 0, "x has zero size");
TORCH_CHECK(f.numel() > 0, "f has zero size");
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
// Create output tensor.
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
// Initialize CUDA kernel parameters.
upfirdn2d_kernel_params p;
p.x = x.data_ptr();
p.f = f.data_ptr<float>();
p.y = y.data_ptr();
p.up = make_int2(upx, upy);
p.down = make_int2(downx, downy);
p.pad0 = make_int2(padx0, pady0);
p.flip = (flip) ? 1 : 0;
p.gain = gain;
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
// Choose CUDA kernel.
upfirdn2d_kernel_spec spec;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
{
spec = choose_upfirdn2d_kernel<scalar_t>(p);
});
// Set looping options.
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
p.loopMinor = spec.loopMinor;
p.loopX = spec.loopX;
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
// Compute grid size.
dim3 blockSize, gridSize;
if (spec.tileOutW < 0) // large
{
blockSize = dim3(4, 32, 1);
gridSize = dim3(
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
p.launchMajor);
}
else // small
{
blockSize = dim3(256, 1, 1);
gridSize = dim3(
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
p.launchMajor);
}
// Launch CUDA kernel.
void* args[] = {&p};
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
return y;
}
//------------------------------------------------------------------------
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("upfirdn2d", &upfirdn2d);
}
//------------------------------------------------------------------------
#include "hip/hip_runtime.h"
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include "upfirdn2d.h"
//------------------------------------------------------------------------
// Helpers.
template <class T> struct InternalType;
template <> struct InternalType<double> { typedef double scalar_t; };
template <> struct InternalType<float> { typedef float scalar_t; };
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
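// floor_div() rounds toward negative infinity for b > 0 (plain C integer division truncates
// toward zero), e.g. floor_div(-1, 2) == -1 whereas -1 / 2 == 0.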
static __device__ __forceinline__ int floor_div(int a, int b)
{
int t = 1 - a / b;
return (a + t * b) / b - t;
}
//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
typedef typename InternalType<T>::scalar_t scalar_t;
// Calculate thread index.
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
int outY = minorBase / p.launchMinor;
minorBase -= outY * p.launchMinor;
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
int majorBase = blockIdx.z * p.loopMajor;
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
return;
// Setup Y receptive field.
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
if (p.flip)
filterY = p.filterSize.y - 1 - filterY;
// Loop over major, minor, and X.
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
{
int nc = major * p.sizeMinor + minor;
int n = nc / p.inSize.z;
int c = nc - n * p.inSize.z;
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
{
// Setup X receptive field.
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
if (p.flip)
filterX = p.filterSize.x - 1 - filterX;
// Initialize pointers.
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
// Inner loop.
scalar_t v = 0;
for (int y = 0; y < h; y++)
{
for (int x = 0; x < w; x++)
{
v += (scalar_t)(*xp) * (scalar_t)(*fp);
xp += p.inStride.x;
fp += filterStepX;
}
xp += p.inStride.y - w * p.inStride.x;
fp += filterStepY - w * filterStepX;
}
// Store result.
v *= p.gain;
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
}
}
}
//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
typedef typename InternalType<T>::scalar_t scalar_t;
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
__shared__ volatile scalar_t sf[filterH][filterW];
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
// Calculate tile index.
int minorBase = blockIdx.x;
int tileOutY = minorBase / p.launchMinor;
minorBase -= tileOutY * p.launchMinor;
minorBase *= loopMinor;
tileOutY *= tileOutH;
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
int majorBase = blockIdx.z * p.loopMajor;
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
return;
// Load filter (flipped).
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
{
int fy = tapIdx / filterW;
int fx = tapIdx - fy * filterW;
scalar_t v = 0;
if (fx < p.filterSize.x & fy < p.filterSize.y)
{
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
}
sf[fy][fx] = v;
}
// Loop over major and X.
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
{
int baseNC = major * p.sizeMinor + minorBase;
int n = baseNC / p.inSize.z;
int baseC = baseNC - n * p.inSize.z;
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
{
// Load input pixels.
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
int tileInX = floor_div(tileMidX, upx);
int tileInY = floor_div(tileMidY, upy);
__syncthreads();
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
{
int relC = inIdx;
int relInX = relC / loopMinor;
int relInY = relInX / tileInW;
relC -= relInX * loopMinor;
relInX -= relInY * tileInW;
int c = baseC + relC;
int inX = tileInX + relInX;
int inY = tileInY + relInY;
scalar_t v = 0;
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
sx[relInY][relInX][relC] = v;
}
// Loop over output pixels.
__syncthreads();
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
{
int relC = outIdx;
int relOutX = relC / loopMinor;
int relOutY = relOutX / tileOutW;
relC -= relOutX * loopMinor;
relOutX -= relOutY * tileOutW;
int c = baseC + relC;
int outX = tileOutX + relOutX;
int outY = tileOutY + relOutY;
// Setup receptive field.
int midX = tileMidX + relOutX * downx;
int midY = tileMidY + relOutY * downy;
int inX = floor_div(midX, upx);
int inY = floor_div(midY, upy);
int relInX = inX - tileInX;
int relInY = inY - tileInY;
int filterX = (inX + 1) * upx - midX - 1; // flipped
int filterY = (inY + 1) * upy - midY - 1; // flipped
// Inner loop.
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
{
scalar_t v = 0;
#pragma unroll
for (int y = 0; y < filterH / upy; y++)
#pragma unroll
for (int x = 0; x < filterW / upx; x++)
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
v *= p.gain;
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
}
}
}
}
}
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
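// A channel stride of 1 identifies channels_last layout. The two entries above
// are generic large-filter fallbacks; the tables below overwrite them with a
// specialized small-filter kernel whenever the filter footprint fits one.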
// No up/downsampling.
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
}
// 2x upsampling.
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
}
if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
}
if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
}
// 2x downsampling.
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
{
// contiguous
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
}
// 4x upsampling.
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
}
if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
}
if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
}
// 4x downsampling (inefficient).
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
{
// contiguous
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
}
return spec;
}
//------------------------------------------------------------------------
// Template specializations.
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
//------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <c10/util/Half.h>
#include "upfirdn2d.h"
//------------------------------------------------------------------------
// Helpers.
template <class T> struct InternalType;
template <> struct InternalType<double> { typedef double scalar_t; };
template <> struct InternalType<float> { typedef float scalar_t; };
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
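// Accumulation type: half-precision tensors accumulate in float32 so the inner
// products in the kernels do not lose precision; float and double are unchanged.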
static __device__ __forceinline__ int floor_div(int a, int b)
{
int t = 1 - a / b;
return (a + t * b) / b - t;
}
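// floor_div rounds toward negative infinity, unlike C++ '/', which truncates
// toward zero. With the definition above, floor_div(-3, 2) == -2 whereas
// -3 / 2 == -1. This matters below, where padding can make the intermediate
// grid coordinates (midX, midY) negative.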
//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
typedef typename InternalType<T>::scalar_t scalar_t;
// Calculate thread index.
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
int outY = minorBase / p.launchMinor;
minorBase -= outY * p.launchMinor;
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
int majorBase = blockIdx.z * p.loopMajor;
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
return;
// Setup Y receptive field.
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
if (p.flip)
filterY = p.filterSize.y - 1 - filterY;
// Loop over major, minor, and X.
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
{
int nc = major * p.sizeMinor + minor;
int n = nc / p.inSize.z;
int c = nc - n * p.inSize.z;
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
{
// Setup X receptive field.
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
if (p.flip)
filterX = p.filterSize.x - 1 - filterX;
// Initialize pointers.
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
// Inner loop.
scalar_t v = 0;
for (int y = 0; y < h; y++)
{
for (int x = 0; x < w; x++)
{
v += (scalar_t)(*xp) * (scalar_t)(*fp);
xp += p.inStride.x;
fp += filterStepX;
}
xp += p.inStride.y - w * p.inStride.x;
fp += filterStepY - w * filterStepX;
}
// Store result.
v *= p.gain;
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
}
}
}
//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
typedef typename InternalType<T>::scalar_t scalar_t;
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
__shared__ volatile scalar_t sf[filterH][filterW];
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
// Calculate tile index.
int minorBase = blockIdx.x;
int tileOutY = minorBase / p.launchMinor;
minorBase -= tileOutY * p.launchMinor;
minorBase *= loopMinor;
tileOutY *= tileOutH;
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
int majorBase = blockIdx.z * p.loopMajor;
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
return;
// Load filter (flipped).
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
{
int fy = tapIdx / filterW;
int fx = tapIdx - fy * filterW;
scalar_t v = 0;
if (fx < p.filterSize.x & fy < p.filterSize.y)
{
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
}
sf[fy][fx] = v;
}
// Loop over major and X.
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
{
int baseNC = major * p.sizeMinor + minorBase;
int n = baseNC / p.inSize.z;
int baseC = baseNC - n * p.inSize.z;
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
{
// Load input pixels.
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
int tileInX = floor_div(tileMidX, upx);
int tileInY = floor_div(tileMidY, upy);
__syncthreads();
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
{
int relC = inIdx;
int relInX = relC / loopMinor;
int relInY = relInX / tileInW;
relC -= relInX * loopMinor;
relInX -= relInY * tileInW;
int c = baseC + relC;
int inX = tileInX + relInX;
int inY = tileInY + relInY;
scalar_t v = 0;
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
sx[relInY][relInX][relC] = v;
}
// Loop over output pixels.
__syncthreads();
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
{
int relC = outIdx;
int relOutX = relC / loopMinor;
int relOutY = relOutX / tileOutW;
relC -= relOutX * loopMinor;
relOutX -= relOutY * tileOutW;
int c = baseC + relC;
int outX = tileOutX + relOutX;
int outY = tileOutY + relOutY;
// Setup receptive field.
int midX = tileMidX + relOutX * downx;
int midY = tileMidY + relOutY * downy;
int inX = floor_div(midX, upx);
int inY = floor_div(midY, upy);
int relInX = inX - tileInX;
int relInY = inY - tileInY;
int filterX = (inX + 1) * upx - midX - 1; // flipped
int filterY = (inY + 1) * upy - midY - 1; // flipped
// Inner loop.
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
{
scalar_t v = 0;
#pragma unroll
for (int y = 0; y < filterH / upy; y++)
#pragma unroll
for (int x = 0; x < filterW / upx; x++)
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
v *= p.gain;
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
}
}
}
}
}
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
// No up/downsampling.
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
}
// 2x upsampling.
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
}
if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
}
if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
}
// 2x downsampling.
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
// channels_last
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
{
// contiguous
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
}
// 4x upsampling.
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
}
if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
}
if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
}
// 4x downsampling (inefficient).
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
{
// contiguous
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
// channels_last
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
}
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
{
// contiguous
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
// channels_last
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
}
return spec;
}
//------------------------------------------------------------------------
// Template specializations.
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
//------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#include <hip/hip_runtime.h>
//------------------------------------------------------------------------
// CUDA kernel parameters.
struct upfirdn2d_kernel_params
{
const void* x;
const float* f;
void* y;
int2 up;
int2 down;
int2 pad0;
int flip;
float gain;
int4 inSize; // [width, height, channel, batch]
int4 inStride;
int2 filterSize; // [width, height]
int2 filterStride;
int4 outSize; // [width, height, channel, batch]
int4 outStride;
int sizeMinor;
int sizeMajor;
int loopMinor;
int loopMajor;
int loopX;
int launchMinor;
int launchMajor;
};
//------------------------------------------------------------------------
// CUDA kernel specialization.
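// tileOutW/tileOutH are -1 when choose_upfirdn2d_kernel falls back to the
// generic large-filter kernel; otherwise they give the output tile covered by
// one block of the small-filter kernel. loopMinor and loopX report the loop
// factors baked into the selected kernel so the caller can set the matching
// fields in upfirdn2d_kernel_params and size the launch grid accordingly.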
struct upfirdn2d_kernel_spec
{
void* kernel;
int tileOutW;
int tileOutH;
int loopMinor;
int loopX;
};
//------------------------------------------------------------------------
// CUDA kernel selection.
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
//------------------------------------------------------------------------
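//------------------------------------------------------------------------
// Host-side usage, shown only as a minimal illustrative sketch for the
// small-kernel path (the production launcher in upfirdn2d.cpp handles all
// cases; the block size, the 16384 grid cap, and the bare `stream` variable
// below are assumptions for illustration, not part of this header):
//
//   upfirdn2d_kernel_params p = {};                       // filled from the input/filter tensors
//   upfirdn2d_kernel_spec spec = choose_upfirdn2d_kernel<float>(p);
//   p.loopMajor   = (p.sizeMajor - 1) / 16384 + 1;        // keep gridDim.z within limits
//   p.loopMinor   = spec.loopMinor;
//   p.loopX       = spec.loopX;
//   p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
//   p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
//
//   dim3 block(256, 1, 1);                                // small kernels only index threads via threadIdx.x
//   dim3 grid(((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
//             (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
//             p.launchMajor);
//   void* args[] = { &p };
//   hipLaunchKernel(spec.kernel, grid, block, args, 0, stream);
//------------------------------------------------------------------------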