Commit 5b669ae9 authored by gushiqiao

Fix docs

parent bae3d352
assets/figs/offload/fig5_zh.png (image replaced: 23.6 KB → 52.3 KB)
@@ -23,7 +23,7 @@ except ImportError:
     logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
     flash_attn_varlen_func_v3 = None
-if torch.cuda.get_device_capability(0)[0] <= 8 and torch.cuda.get_device_capability(0)[1] <= 9:
+if torch.cuda.get_device_capability(0) == (8, 9):
     try:
         from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
     except ImportError:
...
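For context (not part of the commit): torch.cuda.get_device_capability(0) returns a (major, minor) compute-capability tuple, e.g. (8, 9) on sm_89 (Ada) and (9, 0) on sm_90 (Hopper). The old compound comparison also matched earlier architectures such as (7, 0) or (8, 0); the new check selects sm_89 exactly. A minimal sketch of the difference:

import torch

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability(0)  # e.g. (8, 9) on an sm_89 GPU
    old_match = cap[0] <= 8 and cap[1] <= 9    # also true on (7, 0), (8, 0), ...
    new_match = cap == (8, 9)                  # true only on sm_89
    print(cap, old_match, new_match)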
@@ -29,6 +29,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.mask_map = None
         if self.config["cpu_offload"]:
+            if torch.cuda.get_device_capability(0) == (9, 0):
+                assert self.config["self_attn_1_type"] != "sage_attn2"
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
             else:
...
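The second hunk guards the CPU-offload path: on an sm_90 (Hopper) device, self_attn_1_type must not be "sage_attn2". A minimal sketch of the same guard in isolation (check_offload_config is a hypothetical helper, not part of the commit; the diff places this logic directly in WanTransformerInfer.__init__):

import torch

def check_offload_config(config: dict) -> None:
    # Mirrors the guard added in the diff: reject sage_attn2
    # when offloading to CPU on an sm_90 device.
    if config["cpu_offload"] and torch.cuda.get_device_capability(0) == (9, 0):
        assert config["self_attn_1_type"] != "sage_attn2"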