Commit 735ac4cf authored by comfyanonymous's avatar comfyanonymous
Browse files

Remove pytorch_lightning dependency.

parent cb180b99
import pickle
load = pickle.load
class Empty:
pass
class Unpickler(pickle.Unpickler):
def find_class(self, module, name):
#TODO: safe unpickle
if module.startswith("pytorch_lightning"):
return Empty
return super().find_class(module, name)
import torch import torch
import math import math
import struct import struct
import comfy.checkpoint_pickle
def load_torch_file(ckpt, safe_load=False): def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"): if ckpt.lower().endswith(".safetensors"):
...@@ -14,7 +15,7 @@ def load_torch_file(ckpt, safe_load=False): ...@@ -14,7 +15,7 @@ def load_torch_file(ckpt, safe_load=False):
if safe_load: if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
else: else:
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment