Commit 78528e68 authored by Samuli Laine's avatar Samuli Laine
Browse files

Fixes to #59, #62

parent e44c9a29
......@@ -44,6 +44,7 @@ ENV PYOPENGL_PLATFORM egl
COPY docker/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
RUN pip install --upgrade pip
RUN pip install ninja imageio imageio-ffmpeg
COPY nvdiffrast /tmp/pip/nvdiffrast/
......
......@@ -303,30 +303,36 @@ div.image-parent {
<nav id="TOC">
<ul>
<li><a href="#overview">Overview</a></li>
<li><a href="#installation">Installation</a><ul>
<li><a href="#installation">Installation</a>
<ul>
<li><a href="#linux">Linux</a></li>
<li><a href="#windows">Windows</a></li>
</ul></li>
<li><a href="#primitive-operations">Primitive operations</a><ul>
<li><a href="#primitive-operations">Primitive operations</a>
<ul>
<li><a href="#rasterization">Rasterization</a></li>
<li><a href="#interpolation">Interpolation</a></li>
<li><a href="#texturing">Texturing</a></li>
<li><a href="#antialiasing">Antialiasing</a></li>
</ul></li>
<li><a href="#beyond-the-basics">Beyond the basics</a><ul>
<li><a href="#beyond-the-basics">Beyond the basics</a>
<ul>
<li><a href="#coordinate-systems">Coordinate systems</a></li>
<li><a href="#geometry-and-minibatches-range-mode-vs-instanced-mode">Geometry and minibatches: Range mode vs Instanced mode</a></li>
<li><a href="#image-space-derivatives">Image-space derivatives</a></li>
<li><a href="#mipmaps-and-texture-dimensions">Mipmaps and texture dimensions</a></li>
<li><a href="#running-on-multiple-gpus">Running on multiple GPUs</a><ul>
<li><a href="#running-on-multiple-gpus">Running on multiple GPUs</a>
<ul>
<li><a href="#note-on-torch.nn.dataparallel">Note on torch.nn.DataParallel</a></li>
</ul></li>
<li><a href="#rendering-multiple-depth-layers">Rendering multiple depth layers</a></li>
<li><a href="#differences-between-pytorch-and-tensorflow">Differences between PyTorch and TensorFlow</a><ul>
<li><a href="#differences-between-pytorch-and-tensorflow">Differences between PyTorch and TensorFlow</a>
<ul>
<li><a href="#manual-opengl-contexts-in-pytorch">Manual OpenGL contexts in PyTorch</a></li>
</ul></li>
</ul></li>
<li><a href="#samples">Samples</a><ul>
<li><a href="#samples">Samples</a>
<ul>
<li><a href="#triangle.py"><span>triangle.py</span></a></li>
<li><a href="#cube.py"><span>cube.py</span></a></li>
<li><a href="#earth.py"><span>earth.py</span></a></li>
......@@ -346,7 +352,7 @@ This documentation is intended to serve as a user's guide to nvdiffrast. For det
<blockquote>
<strong>Modular Primitives for High-Performance Differentiable Rendering</strong><br> Samuli Laine, Janne Hellsten, Tero Karras, Yeongho Seol, Jaakko Lehtinen, Timo Aila<br> ACM Transactions on Graphics 39(6) (proc. SIGGRAPH Asia 2020)
</blockquote>
<p>Paper: <a href="http://arxiv.org/abs/2011.03277" class="uri">http://arxiv.org/abs/2011.03277</a><br> GitHub: <a href="https://github.com/NVlabs/nvdiffrast" class="uri">https://github.com/NVlabs/nvdiffrast</a></p>
<p>Paper: <a href="http://arxiv.org/abs/2011.03277">http://arxiv.org/abs/2011.03277</a><br> GitHub: <a href="https://github.com/NVlabs/nvdiffrast">https://github.com/NVlabs/nvdiffrast</a></p>
<div class="image-parent">
<div class="image-caption">
<div class="image-row">
......@@ -365,16 +371,16 @@ Examples of things we've done with nvdiffrast
<li>PyTorch 1.6 (recommended) or TensorFlow 1.14. TensorFlow 2.x is currently not supported.</li>
<li>A high-end NVIDIA GPU, NVIDIA drivers, CUDA 10.2 toolkit, and cuDNN 7.6.</li>
</ul>
<p>To download nvdiffrast, either download the repository at <a href="https://github.com/NVlabs/nvdiffrast" class="uri">https://github.com/NVlabs/nvdiffrast</a> as a .zip file, or clone the repository using git:</p>
<div class="sourceCode" id="cb1"><pre class="sourceCode bash"><code class="sourceCode bash"><a class="sourceLine" id="cb1-1" data-line-number="1"><span class="fu">git</span> clone https://github.com/NVlabs/nvdiffrast</a></code></pre></div>
<p>To download nvdiffrast, either download the repository at <a href="https://github.com/NVlabs/nvdiffrast">https://github.com/NVlabs/nvdiffrast</a> as a .zip file, or clone the repository using git:</p>
<div class="sourceCode" id="cb1"><pre class="sourceCode bash"><code class="sourceCode bash"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="fu">git</span> clone https://github.com/NVlabs/nvdiffrast</span></code></pre></div>
<h3 id="linux">Linux</h3>
<p>We recommend running nvdiffrast on <a href="https://www.docker.com/">Docker</a>. To build a Docker image with nvdiffrast and PyTorch 1.6 installed, run:</p>
<div class="sourceCode" id="cb2"><pre class="sourceCode bash"><code class="sourceCode bash"><a class="sourceLine" id="cb2-1" data-line-number="1"><span class="ex">./run_sample.sh</span> --build-container</a></code></pre></div>
<div class="sourceCode" id="cb2"><pre class="sourceCode bash"><code class="sourceCode bash"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="ex">./run_sample.sh</span> --build-container</span></code></pre></div>
<p>We recommend using Ubuntu, as some Linux distributions might not have all the required packages available — at least CentOS is reportedly problematic.</p>
<p>To try out some of the provided code examples, run:</p>
<div class="sourceCode" id="cb3"><pre class="sourceCode bash"><code class="sourceCode bash"><a class="sourceLine" id="cb3-1" data-line-number="1"><span class="ex">./run_sample.sh</span> ./samples/torch/cube.py --resolution 32</a></code></pre></div>
<div class="sourceCode" id="cb3"><pre class="sourceCode bash"><code class="sourceCode bash"><span id="cb3-1"><a href="#cb3-1" aria-hidden="true" tabindex="-1"></a><span class="ex">./run_sample.sh</span> ./samples/torch/cube.py --resolution 32</span></code></pre></div>
<p>Alternatively, if you have all the dependencies taken care of (consult the included Dockerfile for reference), you can install nvdiffrast in your local Python site-packages by running</p>
<div class="sourceCode" id="cb4"><pre class="sourceCode bash"><code class="sourceCode bash"><a class="sourceLine" id="cb4-1" data-line-number="1"><span class="ex">pip</span> install .</a></code></pre></div>
<div class="sourceCode" id="cb4"><pre class="sourceCode bash"><code class="sourceCode bash"><span id="cb4-1"><a href="#cb4-1" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install .</span></code></pre></div>
<p>at the root of the repository. You can also just add the repository root directory to your <code>PYTHONPATH</code>.</p>
<h3 id="windows">Windows</h3>
<p>On Windows, nvdiffrast requires an external compiler for compiling the CUDA kernels. The development was done using Microsoft Visual Studio 2017 Professional Edition, and this version works with both PyTorch and TensorFlow versions of nvdiffrast. VS 2019 Professional Edition has also been confirmed to work with the PyTorch version of nvdiffrast. Other VS editions besides Professional Edition, including the Community Edition, should work but have not been tested.</p>
......@@ -382,11 +388,11 @@ Examples of things we've done with nvdiffrast
<pre><code>&quot;C:\Program Files (x86)\Microsoft Visual Studio\...\...\VC\Auxiliary\Build\vcvars64.bat&quot;</code></pre>
<p>where the exact path depends on the version and edition of VS you have installed.</p>
<p>To install nvdiffrast in your local site-packages, run:</p>
<div class="sourceCode" id="cb6"><pre class="sourceCode bash"><code class="sourceCode bash"><a class="sourceLine" id="cb6-1" data-line-number="1"><span class="co"># Ninja is required run-time to build PyTorch extensions</span></a>
<a class="sourceLine" id="cb6-2" data-line-number="2"><span class="ex">pip</span> install ninja</a>
<a class="sourceLine" id="cb6-3" data-line-number="3"></a>
<a class="sourceLine" id="cb6-4" data-line-number="4"><span class="co"># Run at the root of the repository to install nvdiffrast</span></a>
<a class="sourceLine" id="cb6-5" data-line-number="5"><span class="ex">pip</span> install .</a></code></pre></div>
<div class="sourceCode" id="cb6"><pre class="sourceCode bash"><code class="sourceCode bash"><span id="cb6-1"><a href="#cb6-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Ninja is required run-time to build PyTorch extensions</span></span>
<span id="cb6-2"><a href="#cb6-2" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install ninja</span>
<span id="cb6-3"><a href="#cb6-3" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb6-4"><a href="#cb6-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Run at the root of the repository to install nvdiffrast</span></span>
<span id="cb6-5"><a href="#cb6-5" aria-hidden="true" tabindex="-1"></a><span class="ex">pip</span> install .</span></code></pre></div>
<p>Instead of <code>pip install .</code> you can also just add the repository root directory to your <code>PYTHONPATH</code>.</p>
<h2 id="primitive-operations">Primitive operations</h2>
<p>Nvdiffrast offers four differentiable rendering primitives: <strong>rasterization</strong>, <strong>interpolation</strong>, <strong>texturing</strong>, and <strong>antialiasing</strong>. The operation of the primitives is described here in a platform-agnostic way. Platform-specific documentation can be found in the API reference section.</p>
......@@ -455,7 +461,7 @@ Background replaced with white
</div>
</div>
<p>The middle image above shows the result of texture sampling using the interpolated texture coordinates from the previous step. Why is the background pink? The texture coordinates <span class="math inline">(<em>s</em>, <em>t</em>)</span> read as zero at those pixels, but that is a perfectly valid point to sample the texture. It happens that Spot's texture (left) has pink color at its <span class="math inline">(0, 0)</span> corner, and therefore all pixels in the background obtain that color as a result of the texture sampling operation. On the right, we have replaced the color of the <q>empty</q> pixels with a white color. Here's one way to do this in PyTorch:</p>
<div class="sourceCode" id="cb7"><pre class="sourceCode python"><code class="sourceCode python"><a class="sourceLine" id="cb7-1" data-line-number="1">img_right <span class="op">=</span> torch.where(rast_out[..., <span class="dv">3</span>:] <span class="op">&gt;</span> <span class="dv">0</span>, img_left, torch.tensor(<span class="fl">1.0</span>).cuda())</a></code></pre></div>
<div class="sourceCode" id="cb7"><pre class="sourceCode python"><code class="sourceCode python"><span id="cb7-1"><a href="#cb7-1" aria-hidden="true" tabindex="-1"></a>img_right <span class="op">=</span> torch.where(rast_out[..., <span class="dv">3</span>:] <span class="op">&gt;</span> <span class="dv">0</span>, img_left, torch.tensor(<span class="fl">1.0</span>).cuda())</span></code></pre></div>
<p>where <code>rast_out</code> is the output of the rasterization operation. We simply test if the <span class="math inline"><em>t</em><em>r</em><em>i</em><em>a</em><em>n</em><em>g</em><em>l</em><em>e</em>_<em>i</em><em>d</em></span> field, i.e., channel 3 of the rasterizer output, is greater than zero, indicating that a triangle was rendered in that pixel. If so, we take the color from the textured image, and otherwise we take constant 1.0.</p>
<h3 id="antialiasing">Antialiasing</h3>
<p>The last of the four primitive operations in nvdiffrast is antialiasing. Based on the geometry input (vertex positions and triangles), it will smooth out discontinuties at silhouette edges in a given image. The smoothing is based on a local approximation of coverage — an approximate integral over a pixel is calculated based on the exact location of relevant edges and the point-sampled colors at pixel centers.</p>
......@@ -762,13 +768,13 @@ Third depth layer
</div>
</div>
<p>The API for depth peeling is based on <code>DepthPeeler</code> object that acts as a <a href="https://docs.python.org/3/reference/datamodel.html#context-managers">context manager</a>, and its <code>rasterize_next_layer</code> method. The first call to <code>rasterize_next_layer</code> is equivalent to calling the traditional <code>rasterize</code> function, and subsequent calls report further depth layers. The arguments for rasterization are specified when instantiating the <code>DepthPeeler</code> object. Concretely, your code might look something like this:</p>
<div class="sourceCode" id="cb8"><pre class="sourceCode python"><code class="sourceCode python"><a class="sourceLine" id="cb8-1" data-line-number="1"><span class="cf">with</span> nvdiffrast.torch.DepthPeeler(glctx, pos, tri, resolution) <span class="im">as</span> peeler:</a>
<a class="sourceLine" id="cb8-2" data-line-number="2"> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_layers):</a>
<a class="sourceLine" id="cb8-3" data-line-number="3"> rast, rast_db <span class="op">=</span> peeler.rasterize_next_layer()</a>
<a class="sourceLine" id="cb8-4" data-line-number="4"> (process <span class="kw">or</span> store the results)</a></code></pre></div>
<div class="sourceCode" id="cb8"><pre class="sourceCode python"><code class="sourceCode python"><span id="cb8-1"><a href="#cb8-1" aria-hidden="true" tabindex="-1"></a><span class="cf">with</span> nvdiffrast.torch.DepthPeeler(glctx, pos, tri, resolution) <span class="im">as</span> peeler:</span>
<span id="cb8-2"><a href="#cb8-2" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(num_layers):</span>
<span id="cb8-3"><a href="#cb8-3" aria-hidden="true" tabindex="-1"></a> rast, rast_db <span class="op">=</span> peeler.rasterize_next_layer()</span>
<span id="cb8-4"><a href="#cb8-4" aria-hidden="true" tabindex="-1"></a> (process <span class="kw">or</span> store the results)</span></code></pre></div>
<p>There is no performance penalty compared to the basic rasterization op if you end up extracting only the first depth layer. In other words, the code above with <code>num_layers=1</code> runs exactly as fast as calling <code>rasterize</code> once.</p>
<p>Depth peeling is only supported in the PyTorch version of nvdiffrast. For implementation reasons, depth peeling reserves the OpenGL context so that other rasterization operations cannot be performed while the peeling is ongoing, i.e., inside the <code>with</code> block. Hence you cannot start a nested depth peeling operation or call <code>rasterize</code> inside the <code>with</code> block, unless you use a different OpenGL context.</p>
<p>For the sake of completeness, let us note the following small caveat: Depth peeling relies on depth values to distinguish surface points from each other. Therefore, culling &quot;previously rendered surface points&quot; actually means culling all surface points at the same or closer depth as those rendered into the pixel in previous passes. This matters only if you have multiple layers of geometry at matching depths — if your geometry consists of, say, nothing but two exactly overlapping triangles, you will see one of them in the first pass but never see the other one in subsequent passes, as it's at the exact depth that is already considered done.</p>
<p>For the sake of completeness, let us note the following small caveat: Depth peeling relies on depth values to distinguish surface points from each other. Therefore, culling "previously rendered surface points" actually means culling all surface points at the same or closer depth as those rendered into the pixel in previous passes. This matters only if you have multiple layers of geometry at matching depths — if your geometry consists of, say, nothing but two exactly overlapping triangles, you will see one of them in the first pass but never see the other one in subsequent passes, as it's at the exact depth that is already considered done.</p>
<h3 id="differences-between-pytorch-and-tensorflow">Differences between PyTorch and TensorFlow</h3>
<p>Nvdiffrast can be used from PyTorch and from TensorFlow 1.x; the latter may change to TensorFlow 2.x if there is demand. These frameworks operate somewhat differently and that is reflected in the respective APIs. Simplifying a bit, in TensorFlow 1.x you construct a persistent graph out of persistent nodes, and run many batches of data through it. In PyTorch, there is no persistent graph or nodes, but a new, ephemeral graph is constructed for each batch of data and destroyed immediately afterwards. Therefore, there is also no persistent state for the operations. There is the <code>torch.nn.Module</code> abstraction for festooning operations with persistent state, but we do not use it.</p>
<p>As a consequence, things that would be part of persistent state of an nvdiffrast operation in TensorFlow must be stored by the user in PyTorch, and supplied to the operations as needed. In practice, this is a very small difference and amounts to just a couple of lines of code in most cases.</p>
......@@ -912,9 +918,7 @@ device.</td></tr></table><div class="methods">Methods, only available if context
<div class="apifunc"><h4><code>nvdiffrast.torch.rasterize(<em>glctx</em>, <em>pos</em>, <em>tri</em>, <em>resolution</em>, <em>ranges</em>=<span class="defarg">None</span>, <em>grad_db</em>=<span class="defarg">True</span>)</code>&nbsp;<span class="sym_function">Function</span></h4>
<p class="shortdesc">Rasterize triangles.</p><p class="longdesc">All input tensors must be contiguous and reside in GPU memory except for
the <code>ranges</code> tensor that, if specified, has to reside in CPU memory. The
output tensors will be contiguous and reside in GPU memory.</p><p class="longdesc">Note: For an unknown reason, on Windows the very first rasterization call using
a newly created OpenGL context may *sometimes* output a blank buffer. This is a
known bug and has never been observed to affect subsequent calls.</p><div class="arguments">Arguments:</div><table class="args"><tr class="arg"><td class="argname">glctx</td><td class="arg_short">OpenGL context of type <code>RasterizeGLContext</code>.</td></tr><tr class="arg"><td class="argname">pos</td><td class="arg_short">Vertex position tensor with dtype <code>torch.float32</code>. To enable range
output tensors will be contiguous and reside in GPU memory.</p><div class="arguments">Arguments:</div><table class="args"><tr class="arg"><td class="argname">glctx</td><td class="arg_short">OpenGL context of type <code>RasterizeGLContext</code>.</td></tr><tr class="arg"><td class="argname">pos</td><td class="arg_short">Vertex position tensor with dtype <code>torch.float32</code>. To enable range
mode, this tensor should have a 2D shape [num_vertices, 4]. To enable
instanced mode, use a 3D shape [minibatch_size, num_vertices, 4].</td></tr><tr class="arg"><td class="argname">tri</td><td class="arg_short">Triangle tensor with shape [num_triangles, 3] and dtype <code>torch.int32</code>.</td></tr><tr class="arg"><td class="argname">resolution</td><td class="arg_short">Output resolution as integer tuple (height, width).</td></tr><tr class="arg"><td class="argname">ranges</td><td class="arg_short">In range mode, tensor with shape [minibatch_size, 2] and dtype
<code>torch.int32</code>, specifying start indices and counts into <code>tri</code>.
......
......@@ -6,4 +6,4 @@
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
__version__ = '0.2.7'
__version__ = '0.2.8'
......@@ -25,6 +25,7 @@ GLUTIL_EXT(void, glGetProgramInfoLog, GLuint program, GLsizei bufSize,
GLUTIL_EXT(void, glGetProgramiv, GLuint program, GLenum pname, GLint* param);
GLUTIL_EXT(void, glLinkProgram, GLuint program);
GLUTIL_EXT(void, glShaderSource, GLuint shader, GLsizei count, const char *const* string, const GLint* length);
GLUTIL_EXT(void, glUniform1f, GLint location, GLfloat v0);
GLUTIL_EXT(void, glUniform2f, GLint location, GLfloat v0, GLfloat v1);
GLUTIL_EXT(void, glUseProgram, GLuint program);
GLUTIL_EXT(void, glVertexAttribPointer, GLuint index, GLint size, GLenum type, GLboolean normalized, GLsizei stride, const void* pointer);
......
......@@ -44,13 +44,21 @@ struct GLDrawCmd
//------------------------------------------------------------------------
// GL helpers.
static void compileGLShader(NVDR_CTX_ARGS, GLuint* pShader, GLenum shaderType, const char* src)
static void compileGLShader(NVDR_CTX_ARGS, const RasterizeGLState& s, GLuint* pShader, GLenum shaderType, const char* src_buf)
{
const char* srcPtr = src;
int srcLength = strlen(src);
std::string src(src_buf);
// Set preprocessor directives.
int n = src.find('\n') + 1; // After first line containing #version directive.
if (s.enableZModify)
src.insert(n, "#define IF_ZMODIFY(x) x\n");
else
src.insert(n, "#define IF_ZMODIFY(x)\n");
const char *cstr = src.c_str();
*pShader = 0;
NVDR_CHECK_GL_ERROR(*pShader = glCreateShader(shaderType));
NVDR_CHECK_GL_ERROR(glShaderSource(*pShader, 1, &srcPtr, &srcLength));
NVDR_CHECK_GL_ERROR(glShaderSource(*pShader, 1, &cstr, 0));
NVDR_CHECK_GL_ERROR(glCompileShader(*pShader));
}
......@@ -103,11 +111,16 @@ void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceId
LOG(INFO) << "OpenGL version reported as " << vMajor << "." << vMinor;
NVDR_CHECK((vMajor == 4 && vMinor >= 4) || vMajor > 4, "OpenGL 4.4 or later is required");
// Enable depth modification workaround on A100 and later.
int capMajor = 0;
NVDR_CHECK_CUDA_ERROR(cudaDeviceGetAttribute(&capMajor, cudaDevAttrComputeCapabilityMajor, cudaDeviceIdx));
s.enableZModify = (capMajor >= 8);
// Number of output buffers.
int num_outputs = s.enableDB ? 2 : 1;
// Set up vertex shader.
compileGLShader(NVDR_CTX_PARAMS, &s.glVertexShader, GL_VERTEX_SHADER,
compileGLShader(NVDR_CTX_PARAMS, s, &s.glVertexShader, GL_VERTEX_SHADER,
"#version 330\n"
"#extension GL_ARB_shader_draw_parameters : enable\n"
STRINGIFY_SHADER_SOURCE(
......@@ -132,7 +145,7 @@ void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceId
// --> du/dX = d((u/w) / (1/w))/dX
// --> du/dX = [d(u/w)/dX - u*d(1/w)/dX] * w
// and we know both d(u/w)/dX and d(1/w)/dX are constant over triangle.
compileGLShader(NVDR_CTX_PARAMS, &s.glGeometryShader, GL_GEOMETRY_SHADER,
compileGLShader(NVDR_CTX_PARAMS, s, &s.glGeometryShader, GL_GEOMETRY_SHADER,
"#version 430\n"
STRINGIFY_SHADER_SOURCE(
layout(triangles) in;
......@@ -188,23 +201,29 @@ void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceId
);
// Set up fragment shader.
compileGLShader(NVDR_CTX_PARAMS, &s.glFragmentShader, GL_FRAGMENT_SHADER,
"#version 330\n"
compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShader, GL_FRAGMENT_SHADER,
"#version 430\n"
STRINGIFY_SHADER_SOURCE(
in vec4 var_uvzw;
in vec4 var_db;
layout(location = 0) out vec4 out_raster;
layout(location = 1) out vec4 out_db;
IF_ZMODIFY(
layout(location = 1) uniform float in_dummy;
in vec4 gl_FragCoord;
out float gl_FragDepth;
)
void main()
{
out_raster = vec4(var_uvzw.x, var_uvzw.y, var_uvzw.z / var_uvzw.w, float(gl_PrimitiveID + 1));
out_db = var_db * var_uvzw.w;
IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;)
}
)
);
// Set up fragment shader for depth peeling.
compileGLShader(NVDR_CTX_PARAMS, &s.glFragmentShaderDP, GL_FRAGMENT_SHADER,
compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShaderDP, GL_FRAGMENT_SHADER,
"#version 430\n"
STRINGIFY_SHADER_SOURCE(
in vec4 var_uvzw;
......@@ -212,6 +231,11 @@ void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceId
layout(binding = 0) uniform sampler2DArray out_prev;
layout(location = 0) out vec4 out_raster;
layout(location = 1) out vec4 out_db;
IF_ZMODIFY(
layout(location = 1) uniform float in_dummy;
in vec4 gl_FragCoord;
out float gl_FragDepth;
)
void main()
{
vec4 prev = texelFetch(out_prev, ivec3(gl_FragCoord.x, gl_FragCoord.y, gl_Layer), 0);
......@@ -220,6 +244,7 @@ void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceId
discard;
out_raster = vec4(var_uvzw.x, var_uvzw.y, depth_new, float(gl_PrimitiveID + 1));
out_db = var_db * var_uvzw.w;
IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;)
}
)
);
......@@ -227,7 +252,7 @@ void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceId
else
{
// Geometry shader without bary differential output.
compileGLShader(NVDR_CTX_PARAMS, &s.glGeometryShader, GL_GEOMETRY_SHADER,
compileGLShader(NVDR_CTX_PARAMS, s, &s.glGeometryShader, GL_GEOMETRY_SHADER,
"#version 330\n"
STRINGIFY_SHADER_SOURCE(
layout(triangles) in;
......@@ -248,25 +273,36 @@ void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceId
);
// Fragment shader without bary differential output.
compileGLShader(NVDR_CTX_PARAMS, &s.glFragmentShader, GL_FRAGMENT_SHADER,
"#version 330\n"
compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShader, GL_FRAGMENT_SHADER,
"#version 430\n"
STRINGIFY_SHADER_SOURCE(
in vec4 var_uvzw;
layout(location = 0) out vec4 out_raster;
IF_ZMODIFY(
layout(location = 1) uniform float in_dummy;
in vec4 gl_FragCoord;
out float gl_FragDepth;
)
void main()
{
out_raster = vec4(var_uvzw.x, var_uvzw.y, var_uvzw.z / var_uvzw.w, float(gl_PrimitiveID + 1));
IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;)
}
)
);
// Depth peeling variant of fragment shader.
compileGLShader(NVDR_CTX_PARAMS, &s.glFragmentShaderDP, GL_FRAGMENT_SHADER,
compileGLShader(NVDR_CTX_PARAMS, s, &s.glFragmentShaderDP, GL_FRAGMENT_SHADER,
"#version 430\n"
STRINGIFY_SHADER_SOURCE(
in vec4 var_uvzw;
layout(binding = 0) uniform sampler2DArray out_prev;
layout(location = 0) out vec4 out_raster;
IF_ZMODIFY(
layout(location = 1) uniform float in_dummy;
in vec4 gl_FragCoord;
out float gl_FragDepth;
)
void main()
{
vec4 prev = texelFetch(out_prev, ivec3(gl_FragCoord.x, gl_FragCoord.y, gl_Layer), 0);
......@@ -274,6 +310,7 @@ void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceId
if (prev.w == 0 || depth_new <= prev.z)
discard;
out_raster = vec4(var_uvzw.x, var_uvzw.y, var_uvzw.z / var_uvzw.w, float(gl_PrimitiveID + 1));
IF_ZMODIFY(gl_FragDepth = gl_FragCoord.z + in_dummy;)
}
)
);
......@@ -327,8 +364,10 @@ void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceId
NVDR_CHECK_GL_ERROR(glGenTextures(1, &s.glPrevOutBuffer));
}
void rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, int posCount, int triCount, int width, int height, int depth)
bool rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, int posCount, int triCount, int width, int height, int depth)
{
bool changes = false;
// Resize vertex buffer?
if (posCount > s.posCount)
{
......@@ -338,6 +377,7 @@ void rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, int posCount, in
LOG(INFO) << "Increasing position buffer size to " << s.posCount << " float32";
NVDR_CHECK_GL_ERROR(glBufferData(GL_ARRAY_BUFFER, s.posCount * sizeof(float), NULL, GL_DYNAMIC_DRAW));
NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterBuffer(&s.cudaPosBuffer, s.glPosBuffer, cudaGraphicsRegisterFlagsWriteDiscard));
changes = true;
}
// Resize triangle buffer?
......@@ -349,6 +389,7 @@ void rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, int posCount, in
LOG(INFO) << "Increasing triangle buffer size to " << s.triCount << " int32";
NVDR_CHECK_GL_ERROR(glBufferData(GL_ELEMENT_ARRAY_BUFFER, s.triCount * sizeof(int32_t), NULL, GL_DYNAMIC_DRAW));
NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterBuffer(&s.cudaTriBuffer, s.glTriBuffer, cudaGraphicsRegisterFlagsWriteDiscard));
changes = true;
}
// Resize framebuffer?
......@@ -391,7 +432,11 @@ void rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, int posCount, in
// (Re-)register all GL buffers into Cuda.
for (int i=0; i < num_outputs; i++)
NVDR_CHECK_CUDA_ERROR(cudaGraphicsGLRegisterImage(&s.cudaColorBuffer[i], s.glColorBuffer[i], GL_TEXTURE_3D, cudaGraphicsRegisterFlagsReadOnly));
changes = true;
}
return changes;
}
void rasterizeRender(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, const float* posPtr, int posCount, int vtxPerInstance, const int32_t* triPtr, int triCount, const int32_t* rangesPtr, int width, int height, int depth, int peeling_idx)
......@@ -477,6 +522,10 @@ void rasterizeRender(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, co
if (s.enableDB)
NVDR_CHECK_GL_ERROR(glUniform2f(0, 2.f / (float)width, 2.f / (float)height));
// Set the dummy uniform if depth modification workaround is active.
if (s.enableZModify)
NVDR_CHECK_GL_ERROR(glUniform1f(1, 0.f));
// Render the meshes.
if (depth == 1 && !rangesPtr)
{
......
......@@ -70,13 +70,14 @@ struct RasterizeGLState // Must be initializable by memset to zero.
cudaGraphicsResource_t cudaPosBuffer;
cudaGraphicsResource_t cudaTriBuffer;
int enableDB;
int enableZModify; // Modify depth in shader, workaround for a rasterization issue on A100.
};
//------------------------------------------------------------------------
// Shared C++ code prototypes.
void rasterizeInitGLContext(NVDR_CTX_ARGS, RasterizeGLState& s, int cudaDeviceIdx);
void rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, int posCount, int triCount, int width, int height, int depth);
bool rasterizeResizeBuffers(NVDR_CTX_ARGS, RasterizeGLState& s, int posCount, int triCount, int width, int height, int depth);
void rasterizeRender(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, const float* posPtr, int posCount, int vtxPerInstance, const int32_t* triPtr, int triCount, const int32_t* rangesPtr, int width, int height, int depth, int peeling_idx);
void rasterizeCopyResults(NVDR_CTX_ARGS, RasterizeGLState& s, cudaStream_t stream, float** outputPtr, int width, int height, int depth);
void rasterizeReleaseBuffers(NVDR_CTX_ARGS, RasterizeGLState& s);
......
......@@ -203,10 +203,6 @@ def rasterize(glctx, pos, tri, resolution, ranges=None, grad_db=True):
the `ranges` tensor that, if specified, has to reside in CPU memory. The
output tensors will be contiguous and reside in GPU memory.
Note: For an unknown reason, on Windows the very first rasterization call using
a newly created OpenGL context may *sometimes* output a blank buffer. This is a
known bug and has never been observed to affect subsequent calls.
Args:
glctx: OpenGL context of type `RasterizeGLContext`.
pos: Vertex position tensor with dtype `torch.float32`. To enable range
......
......@@ -99,7 +99,14 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_fwd(RasterizeGLStateWrapper&
setGLContext(s.glctx);
// Resize all buffers.
rasterizeResizeBuffers(NVDR_CTX_PARAMS, s, posCount, triCount, width, height, depth);
if (rasterizeResizeBuffers(NVDR_CTX_PARAMS, s, posCount, triCount, width, height, depth))
{
#ifdef _WIN32
// Workaround for occasional blank first frame on Windows.
releaseGLContext();
setGLContext(s.glctx);
#endif
}
// Copy input data to GL and render.
const float* posPtr = pos.data_ptr<float>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment