Unverified commit 8421c146, authored by Ivan Skorokhodov, committed by GitHub
Browse files

Use parameters + buffers when deciding upscale_dtype (#9882)

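Sometimes the decoder has no parameters and only buffers (e.g., when all parameters have been manually converted to buffers to avoid packing fp16 and fp32 parameters together with FSDP). In that case `next(iter(self.up_blocks.parameters()))` has nothing to return and raises `StopIteration`, so `upscale_dtype` is now taken from the first parameter or, if there are none, the first buffer.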
parent cfdeebd4
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Dict, Optional, Tuple, Union
import torch
......@@ -94,7 +95,7 @@ class TemporalDecoder(nn.Module):
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment