Commit 5a2d5d59 authored by Samuli Laine

Support user-provided mipmaps

parent ce7063f1
@@ -529,8 +529,8 @@ For 2D textures, the coordinate origin (s,

 One might wonder if it would have been easier to determine the texture footprints simply from the texture coordinates in adjacent pixels, and skip all this derivative rubbish? In easy cases the answer is yes, but silhouettes, occlusions, and discontinuous texture parameterizations would make this approach rather unreliable in practice. Computing the image-space derivatives analytically keeps everything point-like, local, and well-behaved.

 It should be noted that computing gradients related to image-space derivatives is somewhat involved and requires additional computation. At the same time, they are often not crucial for the convergence of the training/optimization. Because of this, the primitive operations in nvdiffrast offer options to disable the calculation of these gradients. We're talking about things like ∂Loss/∂(∂{u, v}/∂{X, Y}) that may look second-order-ish, but they're not.
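For example, the PyTorch rasterizer exposes such a switch as its grad_db flag. A minimal sketch (the triangle and context setup here are illustrative assumptions):

    import torch
    import nvdiffrast.torch as dr

    glctx = dr.RasterizeGLContext()  # OpenGL context on the current device
    pos = torch.tensor([[[-0.8, -0.8, 0.0, 1.0],
                         [ 0.8, -0.8, 0.0, 1.0],
                         [ 0.0,  0.8, 0.0, 1.0]]], device='cuda')  # hypothetical clip-space triangle
    tri = torch.tensor([[0, 1, 2]], dtype=torch.int32, device='cuda')

    # rast_db (the image-space derivative buffer) is still produced for mip
    # level selection, but its gradient path back to pos is skipped.
    rast, rast_db = dr.rasterize(glctx, pos, tri, resolution=[256, 256], grad_db=False)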
 Mipmaps and texture dimensions
 ------------------------------

-Prefiltered texture sampling modes require mipmaps (https://en.wikipedia.org/wiki/Mipmap), i.e., downsampled versions, of the texture. The texture sampling operation can construct these internally, but there are limits to texture dimensions that need to be considered.
+Prefiltered texture sampling modes require mipmaps (https://en.wikipedia.org/wiki/Mipmap), i.e., downsampled versions, of the texture. The texture sampling operation can construct these internally, or you can provide your own mipmap stack, but there are limits to texture dimensions that need to be considered.

-Each mipmap level is constructed by averaging 2×2 pixel patches of the preceding level (or of the texture itself for the first mipmap level). The size of the buffer to be averaged therefore has to be divisible by 2 in both directions. There is one exception: side length of 1 is valid, and it will remain as 1 in the downsampling operation.
+When mipmaps are constructed internally, each mipmap level is constructed by averaging 2×2 pixel patches of the preceding level (or of the texture itself for the first mipmap level). The size of the buffer to be averaged therefore has to be divisible by 2 in both directions. There is one exception: side length of 1 is valid, and it will remain as 1 in the downsampling operation.
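As a rough sketch of the construction rule, assuming NHWC float tensors with even extents (an illustration only, not nvdiffrast's actual CUDA kernel):

    import torch
    import torch.nn.functional as F

    def next_mip_level(prev):
        # prev: [minibatch, height, width, channels], both extents even.
        x = prev.permute(0, 3, 1, 2)               # NHWC -> NCHW for pooling
        x = F.avg_pool2d(x, kernel_size=2)         # average 2x2 pixel patches
        return x.permute(0, 2, 3, 1).contiguous()  # back to NHWC

    level1 = next_mip_level(torch.rand(1, 32, 32, 3))  # -> [1, 16, 16, 3]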
 For example, a 32×32 texture will produce the following mipmap stack:

     Base texture  32×32
     Mip level 1   16×16
     Mip level 2    8×8
     Mip level 3    4×4
     Mip level 4    2×2
     Mip level 5    1×1

@@ -720,6 +720,7 @@ Mip level 5
 Scaling the atlas to, say, 256×32 pixels would feel silly because the dimensions of the sub-images are perfectly fine, and downsampling the different sub-images together — which would happen after the 5×1 resolution — would not make sense anyway. For this reason, the texture sampling operation allows the user to specify the maximum number of mipmap levels to be constructed and used. In this case, setting max_mip_level=5 would stop at the 5×1 mipmap and prevent the error.
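With the PyTorch API this looks as follows (a sketch; the tensor shapes are illustrative assumptions, with the atlas sized so that level 5 is 5×1):

    import torch
    import nvdiffrast.torch as dr

    # Hypothetical 160x32 atlas: 160x32 -> 80x16 -> 40x8 -> 20x4 -> 10x2 -> 5x1.
    tex   = torch.rand(1, 32, 160, 4, device='cuda')
    uv    = torch.rand(1, 256, 256, 2, device='cuda')
    uv_da = torch.zeros(1, 256, 256, 4, device='cuda')

    # Stop at the 5x1 level; without the cap, halving the odd extent 5 would raise an error.
    color = dr.texture(tex, uv, uv_da=uv_da, filter_mode='linear-mipmap-linear',
                       max_mip_level=5)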
 It is a deliberate design choice that nvdiffrast doesn't just stop automatically at a mipmap size it cannot downsample, but requires the user to specify a limit when the texture dimensions are not powers of two. The goal is to avoid bugs where prefiltered texture sampling mysteriously doesn't work due to an oddly sized texture. It would be confusing if a 256×256 texture gave beautifully prefiltered texture samples, a 255×255 texture suddenly had no prefiltering at all, and a 254×254 texture did just a bit of prefiltering (one level) but not more.

+If you compute your own mipmaps, their sizes must follow the scheme described above. There is no need to specify mipmaps all the way to 1×1 resolution, but the stack can end at any point and it will work equivalently to an internally constructed mipmap stack with a max_mip_level limit. Importantly, the gradients of user-provided mipmaps are not propagated automatically to the base texture — naturally so, because nvdiffrast knows nothing about the relation between them. Instead, the tensors that specify the mip levels in a user-provided mipmap stack will receive gradients of their own.
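A sketch of a user-provided stack in the PyTorch API (the shapes, uv, and uv_da are illustrative assumptions):

    import torch
    import nvdiffrast.torch as dr

    # Base texture plus a custom two-level mip stack, optimized independently.
    tex   = torch.rand(1, 64, 64, 3, device='cuda', requires_grad=True)
    mip1  = torch.rand(1, 32, 32, 3, device='cuda', requires_grad=True)
    mip2  = torch.rand(1, 16, 16, 3, device='cuda', requires_grad=True)
    uv    = torch.rand(1, 256, 256, 2, device='cuda')
    uv_da = torch.zeros(1, 256, 256, 4, device='cuda')

    # A list passed as `mip` behaves like max_mip_level=2. After backward(),
    # tex, mip1, and mip2 each hold their own .grad; nothing flows from the
    # mip tensors back into tex.
    color = dr.texture(tex, uv, uv_da=uv_da, mip=[mip1, mip2],
                       filter_mode='linear-mipmap-linear')
    color.sum().backward()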
 Running on multiple GPUs
 ------------------------

 Nvdiffrast supports computation on multiple GPUs in both PyTorch and TensorFlow. As is the convention in PyTorch, the operations are always executed on the device on which the input tensors reside. All GPU input tensors must reside on the same device, and the output tensors will unsurprisingly end up on that same device. In addition, the rasterization operation requires that its OpenGL context was created for the correct device. In TensorFlow, the OpenGL context is automatically created on the device of the rasterization operation when it is executed for the first time.

 On Windows, nvdiffrast implements OpenGL device selection in a way that can be done only once per process — after one context is created, all future ones will end up on the same GPU. Hence you cannot expect to run the rasterization operation on multiple GPUs within the same process. Trying to do so will either cause a crash or incur a significant performance penalty. However, with PyTorch it is common to distribute computation across GPUs by launching a separate process for each GPU, so this is not a huge concern. Note that any OpenGL context created within the same process, even for something like a GUI window, will prevent changing the device later. Therefore, if you want to run the rasterization operation on other than the default GPU, be sure to create its OpenGL context before initializing any other OpenGL-powered libraries.
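In a one-process-per-GPU setup the start of each worker might look like this (a sketch; local_rank is assumed to come from the launcher, and whether RasterizeGLContext accepts an explicit device argument depends on the nvdiffrast version):

    import torch
    import nvdiffrast.torch as dr

    # One process per GPU; local_rank is an assumed launcher-provided value.
    device = torch.device('cuda', local_rank)
    torch.cuda.set_device(device)

    # Create the OpenGL context first, before any other OpenGL-powered
    # library can lock the process to a different GPU (Windows caveat above).
    glctx = dr.RasterizeGLContext(device=device)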
@@ -900,8 +901,12 @@ must have shape [minibatch_size, height, width, 2]. When sampling a cube map

                 texture, must have shape [minibatch_size, height, width, 3].
     uv_da: (Optional) Tensor containing image-space derivatives of texture coordinates.
            Must have same shape as uv except for the last dimension that is to be twice
            as long.
     mip_level_bias: (Optional) Per-pixel bias for mip level selection. If uv_da is omitted,
                     determines mip level directly. Must have shape [minibatch_size, height, width].
-    mip: (Optional) Preconstructed mipmap stack from a texture_construct_mip() call. If not
-         specified, the mipmap stack is constructed internally and discarded afterwards.
+    mip: (Optional) Preconstructed mipmap stack from a texture_construct_mip() call or a list
+         of tensors specifying a custom mipmap stack. Gradients of a custom mipmap stack
+         are not automatically propagated to base texture but the mipmap tensors will
+         receive gradients of their own. If a mipmap stack is not specified but the chosen
+         filter mode requires it, the mipmap stack is constructed internally and
+         discarded afterwards.
     filter_mode: Texture filtering mode to be used. Valid values are 'auto', 'nearest',
                  'linear', 'linear-mipmap-nearest', and 'linear-mipmap-linear'. Mode 'auto'
                  selects 'linear' if neither uv_da nor mip_level_bias is specified, and
                  'linear-mipmap-linear' when at least one of them is specified, these being
...
@@ -6,4 +6,4 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
-__version__ = '0.2.1'
+__version__ = '0.2.2'
@@ -18,7 +18,7 @@ void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p)
    int bufsz = 1024;
    std::string msg = "Mip-map size error - cannot downsample an odd extent greater than 1. Resize the texture so that both spatial extents are powers of two, or limit the number of mip maps using max_mip_level argument.\n";

    int w = p.texWidth;
    int h = p.texHeight;
    bool ew = false;
@@ -29,7 +29,7 @@ void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p)
    msg += "----- ----- ------\n";
    snprintf(buf, bufsz, "base %5d %5d\n", w, h);
    msg += buf;

    int mipTotal = 0;
    int level = 0;
    while ((w|h) > 1 && !(ew || eh)) // Stop at first impossible size.
@@ -59,12 +59,11 @@ void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p)
    NVDR_CHECK(0, msg);
}

-int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p)
+int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p, int* mipOffsets)
{
    // No levels at all?
    if (p.mipLevelLimit == 0)
    {
-        p.mipOffset[0] = 0;
        p.mipLevelMax = 0;
        return 0;
    }
@@ -72,14 +71,14 @@ int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p)
    // Current level size.
    int w = p.texWidth;
    int h = p.texHeight;

-    p.mipOffset[0] = 0;
    int mipTotal = 0;
    int level = 0;
    int c = (p.boundaryMode == TEX_BOUNDARY_MODE_CUBE) ? (p.channels * 6) : p.channels;
+    mipOffsets[0] = 0;

    while ((w|h) > 1)
    {
        // Current level.
        level += 1;

        // Quit if cannot downsample.
@@ -90,7 +89,7 @@ int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p)
        if (w > 1) w >>= 1;
        if (h > 1) h >>= 1;

-        p.mipOffset[level] = mipTotal;
+        mipOffsets[level] = mipTotal; // Store the mip offset (#floats).
        mipTotal += w * h * p.texDepth * c;

        // Hit the level limit?
...
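For intuition, the offsets computed here are just running sums of successive level sizes in floats; an illustrative Python sketch of the same arithmetic (not part of either API):

    def mip_offsets(width, height, depth, channels, level_limit=-1):
        # Offset (in floats) of each mip level inside one flat mip buffer.
        # Level 0 is the base texture, which lives in its own tensor.
        offsets, total, w, h, level = [0], 0, width, height, 0
        while (w > 1 or h > 1) and level != level_limit:
            level += 1
            w, h = max(w // 2, 1), max(h // 2, 1)
            offsets.append(total)              # where this level starts
            total += w * h * depth * channels  # running size of the mip buffer
        return offsets, total                  # total = floats needed for all mips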
@@ -629,8 +629,8 @@ static __forceinline__ __device__ void MipBuildKernelTemplate(const TextureKerne
    int pidx_out = p.channels * (px + sz_out.x * (py + sz_out.y * pz));

    // Input and output pointers.
-    const float* pin = (p.mipLevelOut > 1) ? (p.mip + p.mipOffset[p.mipLevelOut - 1]) : p.tex;
-    float* pout = p.mip + p.mipOffset[p.mipLevelOut];
+    const float* pin = p.tex[p.mipLevelOut - 1];
+    float* pout = (float*)p.tex[p.mipLevelOut];

    // Special case: Input texture height or width is 1.
    if (sz_in.x == 1 || sz_in.y == 1)
@@ -703,7 +703,7 @@ static __forceinline__ __device__ void TextureFwdKernelTemplate(const TextureKer
    {
        int tc = indexTextureNearest<CUBE_MODE>(p, uv, tz);
        tc *= p.channels;
-        const float* pIn = p.tex;
+        const float* pIn = p.tex[0];

        // Copy if valid tc, otherwise output zero.
        for (int i=0; i < p.channels; i += C)
@@ -721,7 +721,7 @@ static __forceinline__ __device__ void TextureFwdKernelTemplate(const TextureKer
    // Get texel indices and pointer for level 0.
    int4 tc0 = make_int4(0, 0, 0, 0);
    float2 uv0 = indexTextureLinear<CUBE_MODE>(p, uv, tz, tc0, level0);
-    const float* pIn0 = level0 ? (p.mip + p.mipOffset[level0]) : p.tex;
+    const float* pIn0 = p.tex[level0];
    bool corner0 = CUBE_MODE && ((tc0.x | tc0.y | tc0.z | tc0.w) < 0);
    tc0 *= p.channels;
@@ -741,7 +741,7 @@ static __forceinline__ __device__ void TextureFwdKernelTemplate(const TextureKer
    // Get texel indices and pointer for level 1.
    int4 tc1 = make_int4(0, 0, 0, 0);
    float2 uv1 = indexTextureLinear<CUBE_MODE>(p, uv, tz, tc1, level1);
-    const float* pIn1 = level1 ? (p.mip + p.mipOffset[level1]) : p.tex;
+    const float* pIn1 = p.tex[level1];
    bool corner1 = CUBE_MODE && ((tc1.x | tc1.y | tc1.z | tc1.w) < 0);
    tc1 *= p.channels;
@@ -851,13 +851,13 @@ static __forceinline__ __device__ void MipGradKernelTemplate(const TextureKernel
        x >>= 1;
        y >>= 1;
-        T* pIn = (T*)(p.gradTexMip + p.mipOffset[level] + (x + sz.x * (y + sz.y * pz)) * p.channels);
+        T* pIn = (T*)(p.gradTex[level] + (x + sz.x * (y + sz.y * pz)) * p.channels);
        for (int i=0; i < c; i++)
            accum_from_mem(TEXEL_ACCUM(i * C), sharedStride, pIn[i], w);
    }

    // Add to main texture gradients.
-    T* pOut = (T*)(p.gradTex + (px + p.texWidth * (py + p.texHeight * pz)) * p.channels);
+    T* pOut = (T*)(p.gradTex[0] + (px + p.texWidth * (py + p.texHeight * pz)) * p.channels);
    for (int i=0; i < c; i++)
        accum_to_mem(pOut[i], TEXEL_ACCUM(i * C), sharedStride);
}
@@ -953,7 +953,7 @@ static __forceinline__ __device__ void TextureGradKernelTemplate(const TextureKe
            return; // Outside texture.

        tc *= p.channels;
-        float* pOut = p.gradTex;
+        float* pOut = p.gradTex[0];

        // Accumulate texture gradients.
        for (int i=0; i < p.channels; i++)
@@ -977,8 +977,8 @@ static __forceinline__ __device__ void TextureGradKernelTemplate(const TextureKe
    // Get texel indices and pointers for level 0.
    int4 tc0 = make_int4(0, 0, 0, 0);
    float2 uv0 = indexTextureLinear<CUBE_MODE>(p, uv, tz, tc0, level0);
-    const float* pIn0 = level0 ? (p.mip + p.mipOffset[level0]) : p.tex;
-    float* pOut0 = level0 ? (p.gradTexMip + p.mipOffset[level0]) : p.gradTex;
+    const float* pIn0 = p.tex[level0];
+    float* pOut0 = p.gradTex[level0];
    bool corner0 = CUBE_MODE && ((tc0.x | tc0.y | tc0.z | tc0.w) < 0);
    tc0 *= p.channels;
@@ -1024,8 +1024,8 @@ static __forceinline__ __device__ void TextureGradKernelTemplate(const TextureKe
    // Get texel indices and pointers for level 1.
    int4 tc1 = make_int4(0, 0, 0, 0);
    float2 uv1 = indexTextureLinear<CUBE_MODE>(p, uv, tz, tc1, level1);
-    const float* pIn1 = level1 ? (p.mip + p.mipOffset[level1]) : p.tex;
-    float* pOut1 = level1 ? (p.gradTexMip + p.mipOffset[level1]) : p.gradTex;
+    const float* pIn1 = p.tex[level1];
+    float* pOut1 = p.gradTex[level1];
    bool corner1 = CUBE_MODE && ((tc1.x | tc1.y | tc1.z | tc1.w) < 0);
    tc1 *= p.channels;
...
@@ -21,7 +21,7 @@
#define TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT 8
#define TEX_GRAD_MAX_MIP_KERNEL_BLOCK_WIDTH 8
#define TEX_GRAD_MAX_MIP_KERNEL_BLOCK_HEIGHT 8
-#define TEX_MAX_MIP_LEVEL 14 // Currently a texture cannot be larger than 2 GB because we use 32-bit indices everywhere.
+#define TEX_MAX_MIP_LEVEL 16 // Currently a texture cannot be larger than 2 GB because we use 32-bit indices everywhere.
#define TEX_MODE_NEAREST 0 // Nearest on base level.
#define TEX_MODE_LINEAR 1 // Bilinear on base level.
#define TEX_MODE_LINEAR_MIPMAP_NEAREST 2 // Bilinear on nearest mip level.
@@ -38,15 +38,13 @@
struct TextureKernelParams
{
-    const float* tex;                    // Incoming texture buffer.
+    const float* tex[TEX_MAX_MIP_LEVEL]; // Incoming texture buffer with mip levels.
    const float* uv;            // Incoming texcoord buffer.
    const float* uvDA;          // Incoming uv pixel diffs or NULL.
    const float* mipLevelBias;  // Incoming mip level bias or NULL.
    const float* dy;            // Incoming output gradient.
-    float* mip;                 // Mip data buffer.
    float* out;                 // Outgoing texture data.
-    float* gradTex;             // Outgoing texture gradient.
-    float* gradTexMip;          // Temporary texture gradients for mip levels > 0.
+    float* gradTex[TEX_MAX_MIP_LEVEL];   // Outgoing texture gradients with mip levels.
    float* gradUV;              // Outgoing texcoord gradient.
    float* gradUVDA;            // Outgoing texcoord pixel differential gradient.
    float* gradMipLevelBias;    // Outgoing mip level bias gradient.
@@ -63,7 +61,6 @@ struct TextureKernelParams
    int texDepth;               // Texture depth.
    int n;                      // Minibatch size.
    int mipLevelMax;            // Maximum mip level index. Zero if mips disabled.
-    int mipOffset[TEX_MAX_MIP_LEVEL]; // Offsets in mip data. 0: unused, 1+: offset to mip.
    int mipLevelOut;            // Mip level being calculated in builder kernel.
};
@@ -71,7 +68,7 @@ struct TextureKernelParams
// C++ helper function prototypes.
void raiseMipSizeError(NVDR_CTX_ARGS, const TextureKernelParams& p);
-int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p);
+int calculateMipInfo(NVDR_CTX_ARGS, TextureKernelParams& p, int* mipOffsets);

//------------------------------------------------------------------------
// Macros.
...
@@ -97,7 +97,7 @@ struct TextureFwdOp : public OpKernel
        }

        // Get input pointers.
-        p.tex = tex.flat<float>().data();
+        p.tex[0] = tex.flat<float>().data();
        p.uv = uv.flat<float>().data();
        p.uvDA = p.enableMip ? uv_da.flat<float>().data() : 0;
@@ -120,10 +120,12 @@ struct TextureFwdOp : public OpKernel
            channel_div_idx = 1; // Channel count divisible by 2.

        // Mip-related setup.
+        float* pmip = 0;
        if (p.enableMip)
        {
            // Generate mip offsets.
-            int mipTotal = calculateMipInfo(ctx, p);
+            int mipOffsets[TEX_MAX_MIP_LEVEL];
+            int mipTotal = calculateMipInfo(ctx, p, mipOffsets);

            // Mip output tensor.
            Tensor* mip_tensor = NULL;
@@ -157,7 +159,9 @@ struct TextureFwdOp : public OpKernel
                OP_REQUIRES_OK(ctx, ctx->allocate_output(1, mip_shape, &mip_tensor));
            }
-            p.mip = mip_tensor->flat<float>().data(); // Pointer to data.
+            pmip = mip_tensor->flat<float>().data(); // Pointer to data.
+            for (int i=1; i <= p.mipLevelMax; i++)
+                p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels.

            // Build mip levels if needed.
            if (computeMip)
@@ -181,15 +185,15 @@ struct TextureFwdOp : public OpKernel
        OP_REQUIRES(ctx, !((uintptr_t)p.uv & 7), errors::Internal("uv input tensor not aligned to float2"));
        if ((p.channels & 3) == 0)
        {
-            OP_REQUIRES(ctx, !((uintptr_t)p.tex & 15), errors::Internal("tex input tensor not aligned to float4"));
+            OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 15), errors::Internal("tex input tensor not aligned to float4"));
            OP_REQUIRES(ctx, !((uintptr_t)p.out & 15), errors::Internal("out output tensor not aligned to float4"));
-            OP_REQUIRES(ctx, !((uintptr_t)p.mip & 15), errors::Internal("mip output tensor not aligned to float4"));
+            OP_REQUIRES(ctx, !((uintptr_t)pmip & 15), errors::Internal("mip output tensor not aligned to float4"));
        }
        if ((p.channels & 1) == 0)
        {
-            OP_REQUIRES(ctx, !((uintptr_t)p.tex & 7), errors::Internal("tex input tensor not aligned to float2"));
+            OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 7), errors::Internal("tex input tensor not aligned to float2"));
            OP_REQUIRES(ctx, !((uintptr_t)p.out & 7), errors::Internal("out output tensor not aligned to float2"));
-            OP_REQUIRES(ctx, !((uintptr_t)p.mip & 7), errors::Internal("mip output tensor not aligned to float2"));
+            OP_REQUIRES(ctx, !((uintptr_t)pmip & 7), errors::Internal("mip output tensor not aligned to float2"));
        }
        if (!cube_mode)
            OP_REQUIRES(ctx, !((uintptr_t)p.uvDA & 15), errors::Internal("uv_da input tensor not aligned to float4"));
@@ -278,7 +282,7 @@ struct TextureGradOp : public OpKernel
        TextureKernelParams& p = m_attribs;
        cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
        bool cube_mode = (p.boundaryMode == TEX_BOUNDARY_MODE_CUBE);

        // Get input.
        const Tensor& tex = ctx->input(0);
        const Tensor& uv = ctx->input(1);
@@ -325,13 +329,13 @@
            else
                OP_REQUIRES(ctx, uv_da.dims() == 4 && uv_da.dim_size(0) == p.n && uv_da.dim_size(1) == p.imgHeight && uv_da.dim_size(2) == p.imgWidth && uv_da.dim_size(3) == 6, errors::InvalidArgument("uv_da must have shape [minibatch_size, height, width, 6] in cube map mode"));
        }

        // Get input pointers.
-        p.tex = tex.flat<float>().data();
+        p.tex[0] = tex.flat<float>().data();
        p.uv = uv.flat<float>().data();
        p.dy = dy.flat<float>().data();
        p.uvDA = p.enableMip ? uv_da.flat<float>().data() : 0;
-        p.mip = p.enableMip ? (float*)mip.flat<float>().data() : 0;
+        float* pmip = p.enableMip ? (float*)mip.flat<float>().data() : 0;

        // Allocate output tensor for tex gradient.
        Tensor* grad_tex_tensor = NULL;
@@ -343,7 +347,7 @@ struct TextureGradOp : public OpKernel
        grad_tex_shape.AddDim(p.texWidth);
        grad_tex_shape.AddDim(p.channels);
        OP_REQUIRES_OK(ctx, ctx->allocate_output(0, grad_tex_shape, &grad_tex_tensor));
-        p.gradTex = grad_tex_tensor->flat<float>().data();
+        p.gradTex[0] = grad_tex_tensor->flat<float>().data();

        // Allocate output tensor for uv gradient.
        if (p.filterMode != TEX_MODE_NEAREST)
@@ -376,26 +380,33 @@ struct TextureGradOp : public OpKernel
        // Mip-related setup.
        Tensor grad_mip_tensor;
+        float* pgradMip = 0;
        if (p.enableMip)
        {
            // Generate mip offsets.
-            int mipTotal = calculateMipInfo(ctx, p);
+            int mipOffsets[TEX_MAX_MIP_LEVEL];
+            int mipTotal = calculateMipInfo(ctx, p, mipOffsets);

            // Get space for temporary mip gradients.
            TensorShape grad_mip_shape;
            grad_mip_shape.AddDim(mipTotal);
            ctx->allocate_temp(DT_FLOAT, grad_mip_shape, &grad_mip_tensor);
-            p.gradTexMip = grad_mip_tensor.flat<float>().data();
+            pgradMip = grad_mip_tensor.flat<float>().data();
+            for (int i=1; i <= p.mipLevelMax; i++)
+            {
+                p.tex[i] = pmip + mipOffsets[i];         // Pointers to mip levels.
+                p.gradTex[i] = pgradMip + mipOffsets[i]; // Pointers to mip gradients.
+            }

            // Clear mip gradients.
-            OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.gradTexMip, 0, mipTotal * sizeof(float), stream));
+            OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(pgradMip, 0, mipTotal * sizeof(float), stream));
        }

        // Initialize texture gradients to zero.
        int texBytes = p.texHeight * p.texWidth * p.texDepth * p.channels * sizeof(float);
        if (cube_mode)
            texBytes *= 6;
-        OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.gradTex, 0, texBytes, stream));
+        OP_CHECK_CUDA_ERROR(ctx, cudaMemsetAsync(p.gradTex[0], 0, texBytes, stream));

        // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned.
        if (!cube_mode)
@@ -412,17 +423,19 @@ struct TextureGradOp : public OpKernel
        }
        if ((p.channels & 3) == 0)
        {
-            OP_REQUIRES(ctx, !((uintptr_t)p.tex & 15), errors::Internal("tex input tensor not aligned to float4"));
-            OP_REQUIRES(ctx, !((uintptr_t)p.gradTex & 15), errors::Internal("grad_tex output tensor not aligned to float4"));
+            OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 15), errors::Internal("tex input tensor not aligned to float4"));
+            OP_REQUIRES(ctx, !((uintptr_t)p.gradTex[0] & 15), errors::Internal("grad_tex output tensor not aligned to float4"));
            OP_REQUIRES(ctx, !((uintptr_t)p.dy & 15), errors::Internal("dy input tensor not aligned to float4"));
-            OP_REQUIRES(ctx, !((uintptr_t)p.mip & 15), errors::Internal("mip input tensor not aligned to float4"));
+            OP_REQUIRES(ctx, !((uintptr_t)pmip & 15), errors::Internal("mip input tensor not aligned to float4"));
+            OP_REQUIRES(ctx, !((uintptr_t)pgradMip & 15), errors::Internal("internal mip gradient tensor not aligned to float4"));
        }
        if ((p.channels & 1) == 0)
        {
-            OP_REQUIRES(ctx, !((uintptr_t)p.tex & 7), errors::Internal("tex input tensor not aligned to float2"));
-            OP_REQUIRES(ctx, !((uintptr_t)p.gradTex & 7), errors::Internal("grad_tex output tensor not aligned to float2"));
+            OP_REQUIRES(ctx, !((uintptr_t)p.tex[0] & 7), errors::Internal("tex input tensor not aligned to float2"));
+            OP_REQUIRES(ctx, !((uintptr_t)p.gradTex[0] & 7), errors::Internal("grad_tex output tensor not aligned to float2"));
            OP_REQUIRES(ctx, !((uintptr_t)p.dy & 7), errors::Internal("dy output tensor not aligned to float2"));
-            OP_REQUIRES(ctx, !((uintptr_t)p.mip & 7), errors::Internal("mip input tensor not aligned to float2"));
+            OP_REQUIRES(ctx, !((uintptr_t)pmip & 7), errors::Internal("mip input tensor not aligned to float2"));
+            OP_REQUIRES(ctx, !((uintptr_t)pgradMip & 7), errors::Internal("internal mip gradient tensor not aligned to float2"));
        }

        // Choose launch parameters for main gradient kernel.
@@ -430,7 +443,7 @@ struct TextureGradOp : public OpKernel
        dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_KERNEL_BLOCK_HEIGHT, p.imgWidth, p.imgHeight);
        dim3 gridSize = getLaunchGridSize(blockSize, p.imgWidth, p.imgHeight, p.n);
        void* func_tbl[TEX_MODE_COUNT * 2] = {
            (void*)TextureGradKernelNearest,
            (void*)TextureGradKernelLinear,
            (void*)TextureGradKernelLinearMipmapNearest,
...
@@ -324,26 +324,29 @@ def interpolate(attr, rast, tri, rast_db=None, diff_attrs=None):
 # Linear-mipmap-linear and linear-mipmap-nearest: Mipmaps enabled.
 class _texture_func_mip(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, filter_mode, tex, uv, uv_da, mip_level_bias, mip, filter_mode_enum, boundary_mode_enum):
+    def forward(ctx, filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack):
+        empty = torch.tensor([])
         if uv_da is None:
-            uv_da = torch.tensor([])
+            uv_da = empty
         if mip_level_bias is None:
-            mip_level_bias = torch.tensor([])
+            mip_level_bias = empty
-        out = _get_plugin().texture_fwd_mip(tex, uv, uv_da, mip_level_bias, mip, filter_mode_enum, boundary_mode_enum)
-        ctx.save_for_backward(tex, uv, uv_da, mip_level_bias)
-        ctx.saved_misc = filter_mode, mip, filter_mode_enum, boundary_mode_enum
+        if mip_wrapper is None:
+            mip_wrapper = _get_plugin().TextureMipWrapper()
+        out = _get_plugin().texture_fwd_mip(tex, uv, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum)
+        ctx.save_for_backward(tex, uv, uv_da, mip_level_bias, *mip_stack)
+        ctx.saved_misc = filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum
         return out

     @staticmethod
     def backward(ctx, dy):
-        tex, uv, uv_da, mip_level_bias = ctx.saved_variables
-        filter_mode, mip, filter_mode_enum, boundary_mode_enum = ctx.saved_misc
+        tex, uv, uv_da, mip_level_bias, *mip_stack = ctx.saved_variables
+        filter_mode, mip_wrapper, filter_mode_enum, boundary_mode_enum = ctx.saved_misc
         if filter_mode == 'linear-mipmap-linear':
-            g_tex, g_uv, g_uv_da, g_mip_level_bias = _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip, filter_mode_enum, boundary_mode_enum)
-            return None, g_tex, g_uv, g_uv_da, g_mip_level_bias, None, None, None
+            g_tex, g_uv, g_uv_da, g_mip_level_bias, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum)
+            return (None, g_tex, g_uv, g_uv_da, g_mip_level_bias, None, None, None) + tuple(g_mip_stack)
         else: # linear-mipmap-nearest
-            g_tex, g_uv = _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip_level_bias, mip, filter_mode_enum, boundary_mode_enum)
-            return None, g_tex, g_uv, None, None, None, None, None
+            g_tex, g_uv, g_mip_stack = _get_plugin().texture_grad_linear_mipmap_nearest(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode_enum, boundary_mode_enum)
+            return (None, g_tex, g_uv, None, None, None, None, None) + tuple(g_mip_stack)

 # Linear and nearest: Mipmaps disabled.
 class _texture_func(torch.autograd.Function):
@@ -386,8 +389,12 @@ def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='aut
                as long.
         mip_level_bias: (Optional) Per-pixel bias for mip level selection. If `uv_da` is omitted,
                         determines mip level directly. Must have shape [minibatch_size, height, width].
-        mip: (Optional) Preconstructed mipmap stack from a `texture_construct_mip()` call. If not
-             specified, the mipmap stack is constructed internally and discarded afterwards.
+        mip: (Optional) Preconstructed mipmap stack from a `texture_construct_mip()` call or a list
+             of tensors specifying a custom mipmap stack. Gradients of a custom mipmap stack
+             are not automatically propagated to base texture but the mipmap tensors will
+             receive gradients of their own. If a mipmap stack is not specified but the chosen
+             filter mode requires it, the mipmap stack is constructed internally and
+             discarded afterwards.
         filter_mode: Texture filtering mode to be used. Valid values are 'auto', 'nearest',
                      'linear', 'linear-mipmap-nearest', and 'linear-mipmap-linear'. Mode 'auto'
                      selects 'linear' if neither `uv_da` nor `mip_level_bias` is specified, and
                      'linear-mipmap-linear' when at least one of them is specified, these being
@@ -437,14 +444,20 @@ def texture(tex, uv, uv_da=None, mip_level_bias=None, mip=None, filter_mode='aut
     # Construct a mipmap if necessary.
     if 'mipmap' in filter_mode:
+        mip_wrapper, mip_stack = None, []
         if mip is not None:
-            assert isinstance(mip, _get_plugin().TextureMipWrapper)
+            assert isinstance(mip, (_get_plugin().TextureMipWrapper, list))
+            if isinstance(mip, list):
+                assert all(isinstance(x, torch.Tensor) for x in mip)
+                mip_stack = mip
+            else:
+                mip_wrapper = mip
         else:
-            mip = _get_plugin().texture_construct_mip(tex, max_mip_level, boundary_mode == 'cube')
+            mip_wrapper = _get_plugin().texture_construct_mip(tex, max_mip_level, boundary_mode == 'cube')

     # Choose stub.
     if filter_mode == 'linear-mipmap-linear' or filter_mode == 'linear-mipmap-nearest':
-        return _texture_func_mip.apply(filter_mode, tex, uv, uv_da, mip_level_bias, mip, filter_mode_enum, boundary_mode_enum)
+        return _texture_func_mip.apply(filter_mode, tex, uv, uv_da, mip_level_bias, mip_wrapper, filter_mode_enum, boundary_mode_enum, *mip_stack)
     else:
         return _texture_func.apply(filter_mode, tex, uv, filter_mode_enum, boundary_mode_enum)
...
@@ -17,6 +17,8 @@
#define OP_RETURN_TT std::tuple<torch::Tensor, torch::Tensor>
#define OP_RETURN_TTT std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
#define OP_RETURN_TTTT std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+#define OP_RETURN_TTV std::tuple<torch::Tensor, torch::Tensor, std::vector<torch::Tensor> >
+#define OP_RETURN_TTTTV std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> >

OP_RETURN_TT rasterize_fwd (RasterizeGLStateWrapper& stateWrapper, torch::Tensor pos, torch::Tensor tri, std::tuple<int, int> resolution, torch::Tensor ranges);
OP_RETURN_T rasterize_grad (torch::Tensor pos, torch::Tensor tri, torch::Tensor out, torch::Tensor dy);
@@ -27,11 +29,11 @@ OP_RETURN_TT interpolate_grad (torch::Tensor attr, tor
OP_RETURN_TTT interpolate_grad_da (torch::Tensor attr, torch::Tensor rast, torch::Tensor tri, torch::Tensor dy, torch::Tensor rast_db, torch::Tensor dda, bool diff_attrs_all, std::vector<int>& diff_attrs_vec);
TextureMipWrapper texture_construct_mip (torch::Tensor tex, int max_mip_level, bool cube_mode);
OP_RETURN_T texture_fwd (torch::Tensor tex, torch::Tensor uv, int filter_mode, int boundary_mode);
-OP_RETURN_T texture_fwd_mip (torch::Tensor tex, torch::Tensor uv, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip, int filter_mode, int boundary_mode);
+OP_RETURN_T texture_fwd_mip (torch::Tensor tex, torch::Tensor uv, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode);
OP_RETURN_T texture_grad_nearest (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode);
OP_RETURN_TT texture_grad_linear (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode);
-OP_RETURN_TT texture_grad_linear_mipmap_nearest (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip, int filter_mode, int boundary_mode);
-OP_RETURN_TTTT texture_grad_linear_mipmap_linear (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip, int filter_mode, int boundary_mode);
+OP_RETURN_TTV texture_grad_linear_mipmap_nearest (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode);
+OP_RETURN_TTTTV texture_grad_linear_mipmap_linear (torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode);
TopologyHashWrapper antialias_construct_topology_hash (torch::Tensor tri);
OP_RETURN_TT antialias_fwd (torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, TopologyHashWrapper topology_hash);
OP_RETURN_TT antialias_grad (torch::Tensor color, torch::Tensor rast, torch::Tensor pos, torch::Tensor tri, torch::Tensor dy, torch::Tensor work_buffer);
@@ -43,7 +45,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    pybind11::class_<RasterizeGLStateWrapper>(m, "RasterizeGLStateWrapper").def(pybind11::init<bool, bool, int>())
        .def("set_context", &RasterizeGLStateWrapper::setContext)
        .def("release_context", &RasterizeGLStateWrapper::releaseContext);
-    pybind11::class_<TextureMipWrapper>(m, "TextureMipWrapper");
+    pybind11::class_<TextureMipWrapper>(m, "TextureMipWrapper").def(pybind11::init<>());
    pybind11::class_<TopologyHashWrapper>(m, "TopologyHashWrapper");

    // Plumbing to torch/c10 logging system.
...
@@ -125,15 +125,18 @@ TextureMipWrapper texture_construct_mip(torch::Tensor tex, int max_mip_level, bo
    p.channels = tex.size(cube_mode ? 4 : 3);

    // Set texture pointer.
-    p.tex = tex.data_ptr<float>();
+    p.tex[0] = tex.data_ptr<float>();

-    // Set mip offsets and calculate total size.
-    int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p);
+    // Generate mip offsets and calculate total size.
+    int mipOffsets[TEX_MAX_MIP_LEVEL];
+    int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p, mipOffsets);

    // Allocate and set mip tensor.
    torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    torch::Tensor mip = torch::empty({mipTotal}, opts);
-    p.mip = mip.data_ptr<float>();
+    float* pmip = mip.data_ptr<float>();
+    for (int i=1; i <= p.mipLevelMax; i++)
+        p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels.

    // Choose kernel variants based on channel count.
    void* args[] = {&p};
@@ -157,24 +160,25 @@ TextureMipWrapper texture_construct_mip(torch::Tensor tex, int max_mip_level, bo
    }

    // Return the mip tensor in a wrapper.
-    TextureMipWrapper mip_wrap;
-    mip_wrap.mip = mip;
-    mip_wrap.max_mip_level = max_mip_level;
-    mip_wrap.texture_size = tex.sizes().vec();
-    mip_wrap.cube_mode = cube_mode;
-    return mip_wrap;
+    TextureMipWrapper mip_wrapper;
+    mip_wrapper.mip = mip;
+    mip_wrapper.max_mip_level = max_mip_level;
+    mip_wrapper.texture_size = tex.sizes().vec();
+    mip_wrapper.cube_mode = cube_mode;
+    return mip_wrapper;
}

//------------------------------------------------------------------------
// Forward op.

-torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrap, int filter_mode, int boundary_mode)
+torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode)
{
    const at::cuda::OptionalCUDAGuard device_guard(device_of(tex));
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    TextureKernelParams p = {}; // Initialize all fields to zero.
-    torch::Tensor& mip = mip_wrap.mip; // Unwrap.
-    int max_mip_level = mip_wrap.max_mip_level;
+    bool has_mip_stack = (mip_stack.size() > 0);
+    torch::Tensor& mip_w = mip_wrapper.mip; // Unwrap.
+    int max_mip_level = has_mip_stack ? mip_stack.size() : mip_wrapper.max_mip_level;
    set_modes(p, filter_mode, boundary_mode, max_mip_level);

    // See if we have these tensors or not.
@@ -184,7 +188,7 @@ torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor
    if (p.enableMip)
    {
        NVDR_CHECK(has_uv_da || has_mip_level_bias, "mipmapping filter mode requires uv_da and/or mip_level_bias input");
-        NVDR_CHECK(mip.defined(), "mipmapping filter mode requires mip tensor input");
+        NVDR_CHECK(has_mip_stack || mip_w.defined(), "mipmapping filter mode requires mip wrapper or mip stack input");
    }

    // Check inputs.
@@ -193,9 +197,18 @@ torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor
    NVDR_CHECK_F32(tex, uv);
    if (p.enableMip)
    {
-        NVDR_CHECK_DEVICE(mip);
-        NVDR_CHECK_CONTIGUOUS(mip);
-        NVDR_CHECK_F32(mip);
+        if (has_mip_stack)
+        {
+            TORCH_CHECK(at::cuda::check_device(mip_stack), __func__, "(): Mip stack inputs must reside on the correct GPU device");
+            nvdr_check_contiguous(mip_stack, __func__, "(): Mip stack inputs must be contiguous tensors");
+            nvdr_check_f32(mip_stack, __func__, "(): Mip stack inputs must be float32 tensors");
+        }
+        else
+        {
+            NVDR_CHECK_DEVICE(mip_w);
+            NVDR_CHECK_CONTIGUOUS(mip_w);
+            NVDR_CHECK_F32(mip_w);
+        }
        if (has_uv_da)
        {
            NVDR_CHECK_DEVICE(uv_da);
@@ -249,7 +262,7 @@ torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor
    }

    // Get input pointers.
-    p.tex = tex.data_ptr<float>();
+    p.tex[0] = tex.data_ptr<float>();
    p.uv = uv.data_ptr<float>();
    p.uvDA = (p.enableMip && has_uv_da) ? uv_da.data_ptr<float>() : NULL;
    p.mipLevelBias = (p.enableMip && has_mip_level_bias) ? mip_level_bias.data_ptr<float>() : NULL;
@@ -268,13 +281,37 @@ torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor
         channel_div_idx = 1; // Channel count divisible by 2.

     // Mip-related setup.
+    float* pmip = 0;
     if (p.enableMip)
     {
-        // Generate mip offsets, check mipmap size, and set mip data pointer.
-        int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p);
-        NVDR_CHECK(tex.sizes() == mip_wrap.texture_size && cube_mode == mip_wrap.cube_mode, "mip does not match texture size");
-        NVDR_CHECK(mip.sizes().size() == 1 && mip.size(0) == mipTotal, "mip tensor size mismatch");
-        p.mip = mip.data_ptr<float>();
+        if (has_mip_stack)
+        {
+            // Custom mip stack supplied. Check that sizes match and assign.
+            p.mipLevelMax = max_mip_level;
+            for (int i=1; i <= p.mipLevelMax; i++)
+            {
+                torch::Tensor& t = mip_stack[i-1];
+                int2 sz = mipLevelSize(p, i);
+                if (!cube_mode)
+                    NVDR_CHECK(t.sizes().size() == 4 && t.size(0) == tex.size(0) && t.size(1) == sz.y && t.size(2) == sz.x && t.size(3) == p.channels, "mip level size mismatch in custom mip stack");
+                else
+                    NVDR_CHECK(t.sizes().size() == 5 && t.size(0) == tex.size(0) && t.size(1) == 6 && t.size(2) == sz.y && t.size(3) == sz.x && t.size(4) == p.channels, "mip level size mismatch in mip stack");
+                if (sz.x == 1 && sz.y == 1)
+                    NVDR_CHECK(i == p.mipLevelMax, "mip level size mismatch in mip stack");
+                p.tex[i] = t.data_ptr<float>();
+            }
+        }
+        else
+        {
+            // Generate mip offsets, check mipmap size, and set mip data pointer.
+            int mipOffsets[TEX_MAX_MIP_LEVEL];
+            int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p, mipOffsets);
+            NVDR_CHECK(tex.sizes() == mip_wrapper.texture_size && cube_mode == mip_wrapper.cube_mode, "mip does not match texture size");
+            NVDR_CHECK(mip_w.sizes().size() == 1 && mip_w.size(0) == mipTotal, "wrapped mip tensor size mismatch");
+            pmip = mip_w.data_ptr<float>();
+            for (int i=1; i <= p.mipLevelMax; i++)
+                p.tex[i] = pmip + mipOffsets[i]; // Pointers to mip levels.
+        }
     }

     // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned.
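Both size-checking loops rely on mipLevelSize() agreeing with how downsampling shrinks the texture: each level halves width and height, clamping at one texel. A plausible definition, consistent with the checks above (the body is an assumption; the real helper lives with the kernel code, and int2/make_int2 come from the CUDA runtime headers):

// Sketch: expected dimensions of mip level 'level'; level 0 is the base texture.
static int2 mipLevelSize(const TextureKernelParams& p, int level)
{
    int w = p.texWidth  >> level;
    int h = p.texHeight >> level;
    return make_int2(w > 0 ? w : 1, h > 0 ? h : 1); // Clamp each side at one texel.
}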
@@ -282,15 +319,17 @@ torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor
     NVDR_CHECK(!((uintptr_t)p.uv & 7), "uv input tensor not aligned to float2");
     if ((p.channels & 3) == 0)
     {
-        NVDR_CHECK(!((uintptr_t)p.tex & 15), "tex input tensor not aligned to float4");
-        NVDR_CHECK(!((uintptr_t)p.out & 15), "out output tensor not aligned to float4");
-        NVDR_CHECK(!((uintptr_t)p.mip & 15), "mip output tensor not aligned to float4");
+        for (int i=0; i <= p.mipLevelMax; i++)
+            NVDR_CHECK(!((uintptr_t)p.tex[i] & 15), "tex or mip input tensor not aligned to float4");
+        NVDR_CHECK(!((uintptr_t)p.out & 15), "out output tensor not aligned to float4");
+        NVDR_CHECK(!((uintptr_t)pmip & 15), "mip input tensor not aligned to float4");
     }
     if ((p.channels & 1) == 0)
     {
-        NVDR_CHECK(!((uintptr_t)p.tex & 7), "tex input tensor not aligned to float2");
-        NVDR_CHECK(!((uintptr_t)p.out & 7), "out output tensor not aligned to float2");
-        NVDR_CHECK(!((uintptr_t)p.mip & 7), "mip output tensor not aligned to float2");
+        for (int i=0; i <= p.mipLevelMax; i++)
+            NVDR_CHECK(!((uintptr_t)p.tex[i] & 7), "tex or mip input tensor not aligned to float2");
+        NVDR_CHECK(!((uintptr_t)p.out & 7), "out output tensor not aligned to float2");
+        NVDR_CHECK(!((uintptr_t)pmip & 7), "mip input tensor not aligned to float2");
     }
     if (!cube_mode)
         NVDR_CHECK(!((uintptr_t)p.uvDA & 15), "uv_da input tensor not aligned to float4");
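The bitmask tests above are the standard power-of-two alignment check: an address is N-byte aligned, for N a power of two, exactly when (addr & (N-1)) is zero. Hence & 15 tests 16-byte (float4) alignment and & 7 tests 8-byte (float2) alignment. As a self-contained illustration:

#include <cstdint>

// Sketch: (addr & (N-1)) == 0 tests N-byte alignment for power-of-two N.
static bool alignedForFloat4(const void* p) { return ((uintptr_t)p & 15) == 0; } // 16-byte loads/stores
static bool alignedForFloat2(const void* p) { return ((uintptr_t)p & 7) == 0; }  //  8-byte loads/stores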
@@ -372,19 +411,21 @@ torch::Tensor texture_fwd_mip(torch::Tensor tex, torch::Tensor uv, torch::Tensor
 torch::Tensor texture_fwd(torch::Tensor tex, torch::Tensor uv, int filter_mode, int boundary_mode)
 {
     torch::Tensor empty_tensor;
-    return texture_fwd_mip(tex, uv, empty_tensor, empty_tensor, TextureMipWrapper(), filter_mode, boundary_mode);
+    std::vector<torch::Tensor> empty_vector;
+    return texture_fwd_mip(tex, uv, empty_tensor, empty_tensor, TextureMipWrapper(), empty_vector, filter_mode, boundary_mode);
 }

 //------------------------------------------------------------------------
 // Gradient op.
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> texture_grad_linear_mipmap_linear(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrap, int filter_mode, int boundary_mode)
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> > texture_grad_linear_mipmap_linear(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode)
 {
     const at::cuda::OptionalCUDAGuard device_guard(device_of(tex));
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     TextureKernelParams p = {}; // Initialize all fields to zero.
-    torch::Tensor& mip = mip_wrap.mip; // Unwrap.
-    int max_mip_level = mip_wrap.max_mip_level;
+    bool has_mip_stack = (mip_stack.size() > 0);
+    torch::Tensor& mip_w = mip_wrapper.mip; // Unwrap.
+    int max_mip_level = has_mip_stack ? mip_stack.size() : mip_wrapper.max_mip_level;
     set_modes(p, filter_mode, boundary_mode, max_mip_level);

     // See if we have these tensors or not.
@@ -394,7 +435,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> texture_g
     if (p.enableMip)
     {
         NVDR_CHECK(has_uv_da || has_mip_level_bias, "mipmapping filter mode requires uv_da and/or mip_level_bias input");
-        NVDR_CHECK(mip.defined(), "mipmapping filter mode requires mip tensor input");
+        NVDR_CHECK(has_mip_stack || mip_w.defined(), "mipmapping filter mode requires mip wrapper or mip stack input");
     }

     // Check inputs.
@@ -403,9 +444,18 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> texture_g
     NVDR_CHECK_F32(tex, uv);
     if (p.enableMip)
     {
-        NVDR_CHECK_DEVICE(mip);
-        NVDR_CHECK_CONTIGUOUS(mip);
-        NVDR_CHECK_F32(mip);
+        if (has_mip_stack)
+        {
+            TORCH_CHECK(at::cuda::check_device(mip_stack), __func__, "(): Mip stack inputs must reside on the correct GPU device");
+            nvdr_check_contiguous(mip_stack, __func__, "(): Mip stack inputs must be contiguous tensors");
+            nvdr_check_f32(mip_stack, __func__, "(): Mip stack inputs must be float32 tensors");
+        }
+        else
+        {
+            NVDR_CHECK_DEVICE(mip_w);
+            NVDR_CHECK_CONTIGUOUS(mip_w);
+            NVDR_CHECK_F32(mip_w);
+        }
     }
     if (has_uv_da)
     {
         NVDR_CHECK_DEVICE(uv_da);
@@ -463,16 +513,15 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> texture_g
     torch::Tensor dy_ = dy.contiguous();

     // Get input pointers.
-    p.tex = tex.data_ptr<float>();
+    p.tex[0] = tex.data_ptr<float>();
     p.uv = uv.data_ptr<float>();
     p.dy = dy_.data_ptr<float>();
     p.uvDA = (p.enableMip && has_uv_da) ? uv_da.data_ptr<float>() : NULL;
     p.mipLevelBias = (p.enableMip && has_mip_level_bias) ? mip_level_bias.data_ptr<float>() : NULL;
-    p.mip = p.enableMip ? (float*)mip.data_ptr<float>() : NULL;

     // Allocate output tensor for tex gradient.
     torch::Tensor grad_tex = torch::zeros_like(tex);
-    p.gradTex = grad_tex.data_ptr<float>();
+    p.gradTex[0] = grad_tex.data_ptr<float>();

     // Allocate output tensor for uv gradient.
     torch::Tensor grad_uv;
@@ -511,14 +560,49 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> texture_g
     // Mip-related setup.
     torch::Tensor grad_mip;
+    std::vector<torch::Tensor> grad_mip_stack;
+    float* pmip = 0;
+    float* pgradMip = 0;
     if (p.enableMip)
     {
-        // Generate mip offsets and get space for temporary mip gradients.
-        int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p);
-        NVDR_CHECK(tex.sizes() == mip_wrap.texture_size && cube_mode == mip_wrap.cube_mode, "mip does not match texture size");
-        NVDR_CHECK(mip.sizes().size() == 1 && mip.size(0) == mipTotal, "mip tensor size mismatch");
-        grad_mip = torch::zeros_like(mip);
-        p.gradTexMip = grad_mip.data_ptr<float>();
+        if (has_mip_stack)
+        {
+            // Custom mip stack supplied. Check that sizes match, assign, construct gradient tensors.
+            p.mipLevelMax = max_mip_level;
+            for (int i=1; i <= p.mipLevelMax; i++)
+            {
+                torch::Tensor& t = mip_stack[i-1];
+                int2 sz = mipLevelSize(p, i);
+                if (!cube_mode)
+                    NVDR_CHECK(t.sizes().size() == 4 && t.size(0) == tex.size(0) && t.size(1) == sz.y && t.size(2) == sz.x && t.size(3) == p.channels, "mip level size mismatch in mip stack");
+                else
+                    NVDR_CHECK(t.sizes().size() == 5 && t.size(0) == tex.size(0) && t.size(1) == 6 && t.size(2) == sz.y && t.size(3) == sz.x && t.size(4) == p.channels, "mip level size mismatch in mip stack");
+                if (sz.x == 1 && sz.y == 1)
+                    NVDR_CHECK(i == p.mipLevelMax, "mip level size mismatch in mip stack");
+                torch::Tensor g = torch::zeros_like(t);
+                grad_mip_stack.push_back(g);
+                p.tex[i] = t.data_ptr<float>();
+                p.gradTex[i] = g.data_ptr<float>();
+            }
+        }
+        else
+        {
+            // Generate mip offsets and get space for temporary mip gradients.
+            int mipOffsets[TEX_MAX_MIP_LEVEL];
+            int mipTotal = calculateMipInfo(NVDR_CTX_PARAMS, p, mipOffsets);
+            NVDR_CHECK(tex.sizes() == mip_wrapper.texture_size && cube_mode == mip_wrapper.cube_mode, "mip does not match texture size");
+            NVDR_CHECK(mip_w.sizes().size() == 1 && mip_w.size(0) == mipTotal, "mip tensor size mismatch");
+            grad_mip = torch::zeros_like(mip_w);
+            pmip = (float*)mip_w.data_ptr<float>();
+            pgradMip = grad_mip.data_ptr<float>();
+            for (int i=1; i <= p.mipLevelMax; i++)
+            {
+                p.tex[i] = pmip + mipOffsets[i];         // Pointers to mip levels.
+                p.gradTex[i] = pgradMip + mipOffsets[i]; // Pointers to mip gradients.
+            }
+        }
     }

     // Verify that buffers are aligned to allow float2/float4 operations. Unused pointers are zero so always aligned.
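In the wrapper branch, every internally built mip level lives in the single flat mip_w buffer, and calculateMipInfo() is expected to fill mipOffsets[] with each level's starting float offset so that p.tex[i] and p.gradTex[i] reduce to pointer arithmetic. A sketch of that packing arithmetic, assuming plain level-after-level layout (cube maps would multiply each level by its 6 faces, omitted here):

#include <algorithm>

// Sketch: offsets (in floats) of mip levels 1..levelMax packed back to back.
// Returns the total float count, i.e. the expected length of the flat buffer.
static int computeMipOffsets(int w, int h, int depth, int channels, int levelMax, int* offsets)
{
    int total = 0;
    for (int level = 1; level <= levelMax; level++)
    {
        w = std::max(w >> 1, 1);
        h = std::max(h >> 1, 1);
        offsets[level] = total;            // Where this level starts.
        total += w * h * depth * channels; // Floats occupied by this level.
    }
    return total;
}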
@@ -536,17 +620,25 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> texture_g
     }
     if ((p.channels & 3) == 0)
     {
-        NVDR_CHECK(!((uintptr_t)p.tex & 15), "tex input tensor not aligned to float4");
-        NVDR_CHECK(!((uintptr_t)p.gradTex & 15), "grad_tex output tensor not aligned to float4");
-        NVDR_CHECK(!((uintptr_t)p.dy & 15), "dy input tensor not aligned to float4");
-        NVDR_CHECK(!((uintptr_t)p.mip & 15), "mip input tensor not aligned to float4");
+        for (int i=0; i <= p.mipLevelMax; i++)
+        {
+            NVDR_CHECK(!((uintptr_t)p.tex[i] & 15), "tex or mip input tensor not aligned to float4");
+            NVDR_CHECK(!((uintptr_t)p.gradTex[i] & 15), "grad_tex output tensor not aligned to float4");
+        }
+        NVDR_CHECK(!((uintptr_t)p.dy & 15), "dy input tensor not aligned to float4");
+        NVDR_CHECK(!((uintptr_t)pmip & 15), "mip input tensor not aligned to float4");
+        NVDR_CHECK(!((uintptr_t)pgradMip & 15), "internal mip gradient tensor not aligned to float4");
     }
     if ((p.channels & 1) == 0)
     {
-        NVDR_CHECK(!((uintptr_t)p.tex & 7), "tex input tensor not aligned to float2");
-        NVDR_CHECK(!((uintptr_t)p.gradTex & 7), "grad_tex output tensor not aligned to float2");
-        NVDR_CHECK(!((uintptr_t)p.dy & 7), "dy output tensor not aligned to float2");
-        NVDR_CHECK(!((uintptr_t)p.mip & 7), "mip input tensor not aligned to float2");
+        for (int i=0; i <= p.mipLevelMax; i++)
+        {
+            NVDR_CHECK(!((uintptr_t)p.tex[i] & 7), "tex or mip input tensor not aligned to float2");
+            NVDR_CHECK(!((uintptr_t)p.gradTex[i] & 7), "grad_tex output tensor not aligned to float2");
+        }
+        NVDR_CHECK(!((uintptr_t)p.dy & 7), "dy output tensor not aligned to float2");
+        NVDR_CHECK(!((uintptr_t)pmip & 7), "mip input tensor not aligned to float2");
+        NVDR_CHECK(!((uintptr_t)pgradMip & 7), "internal mip gradient tensor not aligned to float2");
     }

     // Choose launch parameters for main gradient kernel.
@@ -583,8 +675,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> texture_g
     // Launch main gradient kernel.
     NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel(func_tbl[func_idx], gridSize, blockSize, args, 0, stream));

-    // Launch kernel to pull gradients from mip levels.
-    if (p.enableMip)
+    // Launch kernel to pull gradients from mip levels. Don't do this if mip stack was supplied - individual level gradients are already there.
+    if (p.enableMip && !has_mip_stack)
     {
         dim3 blockSize = getLaunchBlockSize(TEX_GRAD_MAX_MIP_KERNEL_BLOCK_WIDTH, TEX_GRAD_MAX_MIP_KERNEL_BLOCK_HEIGHT, p.texWidth, p.texHeight);
         dim3 gridSize  = getLaunchGridSize(blockSize, p.texWidth, p.texHeight, p.texDepth * (cube_mode ? 6 : 1));
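This guard is the heart of the new code path: internally built mip levels are a deterministic function of the base texture, so their accumulated gradients must still be pushed back through the 2×2 averaging into grad_tex, whereas user-supplied levels are independent inputs whose gradient tensors are returned as-is. A CPU-side sketch of one step of that pull-back, assuming the fine level is exactly twice the coarse level in each dimension (the real version is a CUDA kernel that also cascades from the coarsest level down):

// Sketch: backward of 2x2 average downsampling. Each coarse texel's gradient
// contributes one quarter to each of its four parent texels on the finer level.
static void pullMipGradient(const float* gradCoarse, float* gradFine,
                            int coarseW, int coarseH, int channels)
{
    int fineW = coarseW * 2;
    for (int y = 0; y < coarseH; y++)
        for (int x = 0; x < coarseW; x++)
            for (int c = 0; c < channels; c++)
            {
                float g = 0.25f * gradCoarse[(y * coarseW + x) * channels + c];
                gradFine[((2 * y    ) * fineW + (2 * x    )) * channels + c] += g;
                gradFine[((2 * y    ) * fineW + (2 * x + 1)) * channels + c] += g;
                gradFine[((2 * y + 1) * fineW + (2 * x    )) * channels + c] += g;
                gradFine[((2 * y + 1) * fineW + (2 * x + 1)) * channels + c] += g;
            }
}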
@@ -595,14 +687,15 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> texture_g
     }

     // Return output tensors.
-    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(grad_tex, grad_uv, grad_uv_da, grad_mip_level_bias);
+    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> >(grad_tex, grad_uv, grad_uv_da, grad_mip_level_bias, grad_mip_stack);
 }

 // Version for nearest filter mode.
 torch::Tensor texture_grad_nearest(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode)
 {
     torch::Tensor empty_tensor;
-    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> result = texture_grad_linear_mipmap_linear(tex, uv, dy, empty_tensor, empty_tensor, TextureMipWrapper(), filter_mode, boundary_mode);
+    std::vector<torch::Tensor> empty_vector;
+    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> > result = texture_grad_linear_mipmap_linear(tex, uv, dy, empty_tensor, empty_tensor, TextureMipWrapper(), empty_vector, filter_mode, boundary_mode);
     return std::get<0>(result);
 }

@@ -610,15 +703,16 @@ torch::Tensor texture_grad_nearest(torch::Tensor tex, torch::Tensor uv, torch::T
 // Version for linear filter mode.
 std::tuple<torch::Tensor, torch::Tensor> texture_grad_linear(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, int filter_mode, int boundary_mode)
 {
     torch::Tensor empty_tensor;
-    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> result = texture_grad_linear_mipmap_linear(tex, uv, dy, empty_tensor, empty_tensor, TextureMipWrapper(), filter_mode, boundary_mode);
+    std::vector<torch::Tensor> empty_vector;
+    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> > result = texture_grad_linear_mipmap_linear(tex, uv, dy, empty_tensor, empty_tensor, TextureMipWrapper(), empty_vector, filter_mode, boundary_mode);
     return std::tuple<torch::Tensor, torch::Tensor>(std::get<0>(result), std::get<1>(result));
 }

 // Version for linear-mipmap-nearest mode.
-std::tuple<torch::Tensor, torch::Tensor> texture_grad_linear_mipmap_nearest(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip, int filter_mode, int boundary_mode)
+std::tuple<torch::Tensor, torch::Tensor, std::vector<torch::Tensor> > texture_grad_linear_mipmap_nearest(torch::Tensor tex, torch::Tensor uv, torch::Tensor dy, torch::Tensor uv_da, torch::Tensor mip_level_bias, TextureMipWrapper mip_wrapper, std::vector<torch::Tensor> mip_stack, int filter_mode, int boundary_mode)
 {
-    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> result = texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip, filter_mode, boundary_mode);
-    return std::tuple<torch::Tensor, torch::Tensor>(std::get<0>(result), std::get<1>(result));
+    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::vector<torch::Tensor> > result = texture_grad_linear_mipmap_linear(tex, uv, dy, uv_da, mip_level_bias, mip_wrapper, mip_stack, filter_mode, boundary_mode);
+    return std::tuple<torch::Tensor, torch::Tensor, std::vector<torch::Tensor> >(std::get<0>(result), std::get<1>(result), std::get<4>(result));
 }

 //------------------------------------------------------------------------
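Putting the pieces together: a non-empty mip_stack selects the stack path everywhere above, the wrapper argument is then ignored, entry i-1 of the stack holds level i at half the previous level's resolution, and the fifth tuple element returns the matching per-level gradients. A hypothetical call (level1 and level2 are placeholder tensors, and tex, uv, dy, uv_da, filter_mode, boundary_mode are assumed to be in scope):

// Sketch: exercising the custom mip stack path of the gradient op.
std::vector<torch::Tensor> mip_stack = { level1, level2 };   // hypothetical half- and quarter-resolution levels
auto result = texture_grad_linear_mipmap_linear(
    tex, uv, dy,
    uv_da, torch::Tensor(),    // image-space derivatives, no mip level bias
    TextureMipWrapper(),       // ignored because a mip stack is supplied
    mip_stack, filter_mode, boundary_mode);
torch::Tensor grad_tex = std::get<0>(result);                    // gradient w.r.t. base texture
std::vector<torch::Tensor> grad_mip_stack = std::get<4>(result); // gradients w.r.t. each supplied level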