Unverified Commit e0392e42 authored by Muyang Li, committed by GitHub

feat: simplify the implementation of the async offloading and support ComfyUI offloading

parent eb901251
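The change below removes the offload manager's private compute stream: the forward pass now records and waits on CUDA events against whatever stream is current, so a host such as ComfyUI that schedules work on its own streams can drive the model directly. The whole diff rests on cross-stream event ordering; the following is a self-contained sketch of that primitive (variable names are illustrative, not from the project):

```python
# Minimal sketch of the cross-stream event hand-off this commit relies on:
# work recorded on one stream gates work queued on another, entirely
# on-device, with no host-side synchronization.
import torch

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()
ready = torch.cuda.Event()

x = torch.ones(1 << 20, device="cuda")
with torch.cuda.stream(producer):
    y = x * 2                    # producer work
    ready.record(producer)       # mark the point consumers must reach
with torch.cuda.stream(consumer):
    consumer.wait_event(ready)   # consumer stalls on-device until `ready` fires
    z = y + 1                    # safe: y is fully written
torch.cuda.synchronize()         # host-side barrier before reading results
print(z.sum().item())
```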
@@ -327,11 +327,9 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
 
+        compute_stream = torch.cuda.current_stream()
         if self.offload:
-            self.offload_manager.initialize()
-            compute_stream = self.offload_manager.compute_stream
-        else:
-            compute_stream = torch.cuda.current_stream()
+            self.offload_manager.initialize(compute_stream)
 
         for block_idx, block in enumerate(self.transformer_blocks):
             with torch.cuda.stream(compute_stream):
                 if self.offload:
@@ -345,7 +343,7 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
                     joint_attention_kwargs=attention_kwargs,
                 )
             if self.offload:
-                self.offload_manager.step()
+                self.offload_manager.step(compute_stream)
 
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
...@@ -46,7 +46,6 @@ class CPUOffloadManager: ...@@ -46,7 +46,6 @@ class CPUOffloadManager:
assert self.num_blocks_on_gpu > 0 assert self.num_blocks_on_gpu > 0
# Two streams: one for compute, one for memory operations, will be initialized in set_device # Two streams: one for compute, one for memory operations, will be initialized in set_device
self.compute_stream = None
self.memory_stream = None self.memory_stream = None
self.compute_done = torch.cuda.Event(blocking=False) self.compute_done = torch.cuda.Event(blocking=False)
@@ -68,7 +67,6 @@ class CPUOffloadManager:
         if self.device == device and not force:
             return
         self.device = device
-        self.compute_stream = torch.cuda.Stream(device=device)
         self.memory_stream = torch.cuda.Stream(device=device)
         for block in self.buffer_blocks:
             block.to(device)
@@ -97,10 +95,12 @@ class CPUOffloadManager:
         block = self.blocks[block_idx]
         copy_params_into(block, self.buffer_blocks[block_idx % 2], non_blocking=non_blocking)
 
-    def step(self, next_stream: torch.cuda.Stream | None = None):
+    def step(self, compute_stream: torch.cuda.Stream | None = None):
         """Move to the next block, triggering preload operations."""
+        if compute_stream is None:
+            compute_stream = torch.cuda.current_stream()
         next_compute_done = torch.cuda.Event()
-        next_compute_done.record(self.compute_stream)
+        next_compute_done.record(compute_stream)
         with torch.cuda.stream(self.memory_stream):
             self.memory_stream.wait_event(self.compute_done)
             self.load_block(self.current_block_idx + 1)  # if the current block is the last block, load the first block
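The new `step()` signature falls back to the ambient stream when none is passed, so callers that do not manage streams still get correct ordering. A tiny sketch of that defaulting pattern (the helper name is illustrative, not part of the project):

```python
# Optional-stream defaulting, as used by the new step() signature: accept a
# stream argument, fall back to the caller's current stream when omitted.
import torch

def record_checkpoint(stream: torch.cuda.Stream | None = None) -> torch.cuda.Event:
    if stream is None:
        stream = torch.cuda.current_stream()  # adopt the caller's ambient stream
    evt = torch.cuda.Event()
    evt.record(stream)                        # marks all work queued so far on it
    return evt

evt = record_checkpoint()  # inside any torch.cuda.stream(...) context, "just works"
```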
@@ -111,13 +111,10 @@ class CPUOffloadManager:
         self.current_block_idx += 1
         if self.current_block_idx < len(self.blocks):
             # get ready for the next compute
-            self.compute_stream.wait_event(self.memory_done)
+            compute_stream.wait_event(self.memory_done)
         else:
             # ready to finish
-            if next_stream is None:
-                torch.cuda.current_stream().wait_event(self.compute_done)
-            else:
-                next_stream.wait_event(self.compute_done)
+            compute_stream.wait_event(self.compute_done)
             self.current_block_idx = 0
             self.forward_counter += 1
             if self.empty_cache_freq > 0 and self.forward_counter % self.empty_cache_freq == 0:
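Putting the pieces together, here is a hedged end-to-end sketch of the double-buffered offload loop the manager implements: two GPU staging buffers, a dedicated memory stream for prefetch copies, and events ordering the two streams. All names (`load`, `gpu_bufs`, ...) are illustrative, and pinned memory is assumed so the host-to-device copies can actually overlap compute.

```python
# Illustrative double-buffered CPU-offload loop (not the project's exact code).
# Block i runs on the caller's compute stream while block i+1's weights are
# prefetched on a memory stream into the other staging buffer.
import torch

n = 8
cpu_blocks = [torch.nn.Linear(1024, 1024) for _ in range(n)]       # weights in CPU RAM
for blk in cpu_blocks:                                             # pin so H2D copies overlap
    for p in blk.parameters():
        p.data = p.data.pin_memory()
gpu_bufs = [torch.nn.Linear(1024, 1024).cuda() for _ in range(2)]  # double buffer

compute = torch.cuda.current_stream()   # caller-owned, as after this commit
memory = torch.cuda.Stream()            # copies only

def load(i: int) -> None:
    """Copy block (i mod n)'s weights into staging buffer (i mod 2)."""
    for d, s in zip(gpu_bufs[i % 2].parameters(), cpu_blocks[i % n].parameters()):
        d.data.copy_(s.data, non_blocking=True)

load(0)                                  # block 0 must be resident before we start
torch.cuda.synchronize()

x = torch.randn(4, 1024, device="cuda")
prev_done = torch.cuda.Event()           # compute finished on the buffer about to be overwritten
prev_done.record(compute)
for i in range(n):
    loaded = torch.cuda.Event()
    with torch.cuda.stream(memory):
        memory.wait_event(prev_done)     # don't overwrite a buffer still being read
        load(i + 1)                      # prefetch next block (wraps to block 0 at the end)
        loaded.record(memory)
    x = gpu_bufs[i % 2](x)               # block i runs on `compute`
    prev_done = torch.cuda.Event()
    prev_done.record(compute)            # next prefetch may now reuse this buffer
    compute.wait_event(loaded)           # block i+1 must not start before its weights arrive
torch.cuda.synchronize()
print(x.shape)
```

The commit keeps exactly this event discipline but lets the caller own `compute`, which is what makes embedding in ComfyUI's own stream scheduling possible.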