"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "5c1063df7fb6850410a35ddfb92cd6efb818fa6e"
Unverified Commit 3b04cdc8 authored by shinetzh, committed by GitHub

fix loop bug in SlicedAttnProcessor (#8836)



* fix loop bug in SlicedAttnProcessor


---------
Co-authored-by: neoshang <neoshang@tencent.com>
parent c009c203
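The bug: `batch_size_attention // self.slice_size` is floor division, so whenever `batch_size_attention` is not an exact multiple of `slice_size`, the loop never visits the final partial slice and the corresponding rows of `hidden_states` are left as zeros. The new bound `(batch_size_attention - 1) // self.slice_size + 1` is ceiling division, which always covers the full batch. A minimal standalone sketch of the difference (hypothetical values, not part of the diff):

```python
# Hypothetical repro of the loop-bound bug: a batch of 5 attention rows
# with a slice size of 2 (values chosen for illustration only).
batch_size_attention, slice_size = 5, 2

# Old bound: floor division -> 2 iterations; the last row is silently skipped.
old = [(i * slice_size, (i + 1) * slice_size)
       for i in range(batch_size_attention // slice_size)]

# New bound: ceiling division -> 3 iterations, covering every row.
new = [(i * slice_size, (i + 1) * slice_size)
       for i in range((batch_size_attention - 1) // slice_size + 1)]

print(old)  # [(0, 2), (2, 4)]          -> row 4 never attended
print(new)  # [(0, 2), (2, 4), (4, 6)]  -> tensor slicing clamps 6 to 5
```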
@@ -2190,7 +2190,7 @@ class SlicedAttnProcessor:
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )
 
-        for i in range(batch_size_attention // self.slice_size):
+        for i in range((batch_size_attention - 1) // self.slice_size + 1):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
@@ -2287,7 +2287,7 @@ class SlicedAttnAddedKVProcessor:
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )
 
-        for i in range(batch_size_attention // self.slice_size):
+        for i in range((batch_size_attention - 1) // self.slice_size + 1):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
...
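For positive integers, the new bound is exactly integer ceiling division. A quick sanity check (illustrative only, not part of the test suite):

```python
import math

# (n - 1) // s + 1 == ceil(n / s) for all positive n and s, so every
# element gets a slice and no extra empty iterations are added.
for n in range(1, 65):
    for s in range(1, 65):
        assert (n - 1) // s + 1 == math.ceil(n / s)
```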
@@ -1351,14 +1351,24 @@ class PipelineTesterMixin:
         pipe.enable_attention_slicing(slice_size=1)
         inputs = self.get_dummy_inputs(generator_device)
-        output_with_slicing = pipe(**inputs)[0]
+        output_with_slicing1 = pipe(**inputs)[0]
+
+        pipe.enable_attention_slicing(slice_size=2)
+        inputs = self.get_dummy_inputs(generator_device)
+        output_with_slicing2 = pipe(**inputs)[0]
 
         if test_max_difference:
-            max_diff = np.abs(to_np(output_with_slicing) - to_np(output_without_slicing)).max()
-            self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
+            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+            self.assertLess(
+                max(max_diff1, max_diff2),
+                expected_max_diff,
+                "Attention slicing should not affect the inference results",
+            )
 
         if test_mean_pixel_difference:
-            assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
+            assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
+            assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))
 
     @unittest.skipIf(
         torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
...
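Worth noting why the test previously missed this bug: with `slice_size=1`, `batch_size_attention // 1 == batch_size_attention`, so floor division drops nothing and the old loop bound happened to be correct. The added `slice_size=2` pass exercises the remainder case whenever `batch_size_attention` is odd, which is presumably what lets the updated test catch the regression.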