"tests/spec_decode/__init__.py" did not exist on "0b98ba15c744f1dfb0ea4f2135e85ca23d572ae1"
Unverified Commit b0755523 authored by milesial's avatar milesial Committed by GitHub
Browse files

[Core] Reduce mm scheduler, get_num_embed overhead (#40143)


Signed-off-by: default avatarmilesial <milesial@users.noreply.github.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 993859ce
......@@ -26,11 +26,8 @@ def test_placeholder_range_get_num_embeds(is_embed, expected):
"is_embed,expected",
[
(None, None),
(
torch.tensor([False, True, False, True, True]),
torch.tensor([0, 1, 1, 2, 3]),
),
(torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
(torch.tensor([False, True, False, True, True]), [0, 1, 1, 2, 3]),
(torch.tensor([True, True, True]), [1, 2, 3]),
],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
......@@ -41,6 +38,6 @@ def test_placeholder_range_embeds_cumsum(is_embed, expected):
assert pr.embeds_cumsum is None
return
assert torch.equal(pr.embeds_cumsum, expected)
assert pr.embeds_cumsum == expected
# cached_property should return the same object on repeated access
assert pr.embeds_cumsum is pr.embeds_cumsum
......@@ -145,14 +145,15 @@ class PlaceholderRange:
"""
@cached_property
def embeds_cumsum(self) -> torch.Tensor | None:
return None if self.is_embed is None else self.is_embed.cumsum(dim=0)
def embeds_cumsum(self) -> list[int] | None:
# python list so python indexing avoids torch C++ overhead/conversions/deallocs
return None if self.is_embed is None else self.is_embed.cumsum(dim=0).tolist()
def get_num_embeds(self) -> int:
if self.embeds_cumsum is None:
return self.length
return int(self.embeds_cumsum[-1])
return self.embeds_cumsum[-1] if self.embeds_cumsum else 0
def get_embeds_indices_in_range(
self, start_idx: int, end_idx: int
......@@ -170,10 +171,8 @@ class PlaceholderRange:
if self.embeds_cumsum is None:
return start_idx, end_idx
embeds_start_idx = (
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
)
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
embeds_start_idx = self.embeds_cumsum[start_idx - 1] if start_idx > 0 else 0
embeds_end_idx = self.embeds_cumsum[end_idx - 1] if end_idx > 0 else 0
return embeds_start_idx, embeds_end_idx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment