Commit cb83f2f8 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Update transformer_infer.py (#249)

parent 2a9a64d0
...@@ -77,6 +77,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer): ...@@ -77,6 +77,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
self.weights_stream_mgr.prefetch_weights_from_disk(blocks) self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
for block_idx in range(len(blocks)): for block_idx in range(len(blocks)):
self.block_idx = block_idx
if block_idx == 0: if block_idx == 0:
block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx) block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
block.to_cuda() block.to_cuda()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment