Commit ca82ade7 authored by comfyanonymous's avatar comfyanonymous
Browse files

Use .itemsize to get dtype size for fp8.

parent 31b0f6f3
......@@ -430,6 +430,13 @@ def dtype_size(dtype):
dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2
elif dtype == torch.float32:
dtype_size = 4
else:
try:
dtype_size = dtype.itemsize
except: #Old pytorch doesn't have .itemsize
pass
return dtype_size
def unet_offload_device():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment