Unverified Commit 8421c146 authored by Ivan Skorokhodov, committed by GitHub

Use parameters + buffers when deciding upscale_dtype (#9882)

Sometimes the decoder has no parameters, only buffers. This happens, for example, when we manually convert all parameters to buffers to avoid packing fp16 and fp32 parameters together under FSDP.
parent cfdeebd4
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 from typing import Dict, Optional, Tuple, Union
 import torch
@@ -94,7 +95,7 @@ class TemporalDecoder(nn.Module):
         sample = self.conv_in(sample)
-        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+        upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             def create_custom_forward(module):
...
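For illustration, here is a minimal, self-contained sketch of the failure mode this patch fixes. The `BufferOnlyBlock` module below is a hypothetical stand-in for `up_blocks` (it is not code from the repository): once its weights live only in buffers, `parameters()` yields nothing and the old `next(iter(...))` lookup raises `StopIteration`, while chaining in `buffers()` still recovers the dtype.

```python
import itertools

import torch
import torch.nn as nn


class BufferOnlyBlock(nn.Module):
    """Hypothetical stand-in for up_blocks whose weights were converted to buffers."""

    def __init__(self):
        super().__init__()
        # Weight registered as a buffer, e.g. to keep FSDP from packing
        # fp16 and fp32 parameters together.
        self.register_buffer("weight", torch.randn(4, 4, dtype=torch.float16))


block = BufferOnlyBlock()

# Old lookup: parameters() is empty, so this raises StopIteration.
try:
    upscale_dtype = next(iter(block.parameters())).dtype
except StopIteration:
    print("no parameters to read a dtype from")

# Patched lookup: falls back to buffers and recovers torch.float16.
upscale_dtype = next(itertools.chain(block.parameters(), block.buffers())).dtype
print(upscale_dtype)  # torch.float16
```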