Unverified commit e0392e42, authored by Muyang Li, committed by GitHub
Browse files

feat: simplify the implementation of the async offloading and support ComfyUI offloading

parent eb901251
......@@ -327,11 +327,9 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
if self.offload:
self.offload_manager.initialize()
compute_stream = self.offload_manager.compute_stream
else:
compute_stream = torch.cuda.current_stream()
if self.offload:
self.offload_manager.initialize(compute_stream)
for block_idx, block in enumerate(self.transformer_blocks):
with torch.cuda.stream(compute_stream):
if self.offload:
......@@ -345,7 +343,7 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
joint_attention_kwargs=attention_kwargs,
)
if self.offload:
self.offload_manager.step()
self.offload_manager.step(compute_stream)
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
......
......@@ -46,7 +46,6 @@ class CPUOffloadManager:
assert self.num_blocks_on_gpu > 0
# Two streams: one for compute, one for memory operations, will be initialized in set_device
self.compute_stream = None
self.memory_stream = None
self.compute_done = torch.cuda.Event(blocking=False)
......@@ -68,7 +67,6 @@ class CPUOffloadManager:
if self.device == device and not force:
return
self.device = device
self.compute_stream = torch.cuda.Stream(device=device)
self.memory_stream = torch.cuda.Stream(device=device)
for block in self.buffer_blocks:
block.to(device)
......@@ -97,10 +95,12 @@ class CPUOffloadManager:
block = self.blocks[block_idx]
copy_params_into(block, self.buffer_blocks[block_idx % 2], non_blocking=non_blocking)
def step(self, next_stream: torch.cuda.Stream | None = None):
def step(self, compute_stream: torch.cuda.Stream | None = None):
"""Move to the next block, triggering preload operations."""
if compute_stream is None:
compute_stream = torch.cuda.current_stream()
next_compute_done = torch.cuda.Event()
next_compute_done.record(self.compute_stream)
next_compute_done.record(compute_stream)
with torch.cuda.stream(self.memory_stream):
self.memory_stream.wait_event(self.compute_done)
self.load_block(self.current_block_idx + 1) # if the current block is the last block, load the first block
......@@ -111,13 +111,10 @@ class CPUOffloadManager:
self.current_block_idx += 1
if self.current_block_idx < len(self.blocks):
# get ready for the next compute
self.compute_stream.wait_event(self.memory_done)
compute_stream.wait_event(self.memory_done)
else:
# ready to finish
if next_stream is None:
torch.cuda.current_stream().wait_event(self.compute_done)
else:
next_stream.wait_event(self.compute_done)
compute_stream.wait_event(self.compute_done)
self.current_block_idx = 0
self.forward_counter += 1
if self.empty_cache_freq > 0 and self.forward_counter % self.empty_cache_freq == 0:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment