Unverified Commit 3b04cdc8 authored by shinetzh's avatar shinetzh Committed by GitHub
Browse files

fix loop bug in SlicedAttnProcessor (#8836)



* fix loop bug in SlicedAttnProcessor


---------
Co-authored-by: default avatarneoshang <neoshang@tencent.com>
parent c009c203
......@@ -2190,7 +2190,7 @@ class SlicedAttnProcessor:
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(batch_size_attention // self.slice_size):
for i in range((batch_size_attention - 1) // self.slice_size + 1):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
......@@ -2287,7 +2287,7 @@ class SlicedAttnAddedKVProcessor:
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(batch_size_attention // self.slice_size):
for i in range((batch_size_attention - 1) // self.slice_size + 1):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
......
......@@ -1351,14 +1351,24 @@ class PipelineTesterMixin:
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing = pipe(**inputs)[0]
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max()
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
if test_mean_pixel_difference:
assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment