Unverified commit b0755523, authored by milesial, committed by GitHub
Browse files

[Core] Reduce mm scheduler, get_num_embed overhead (#40143)


Signed-off-by: milesial <milesial@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 993859ce
...@@ -26,11 +26,8 @@ def test_placeholder_range_get_num_embeds(is_embed, expected): ...@@ -26,11 +26,8 @@ def test_placeholder_range_get_num_embeds(is_embed, expected):
"is_embed,expected", "is_embed,expected",
[ [
(None, None), (None, None),
( (torch.tensor([False, True, False, True, True]), [0, 1, 1, 2, 3]),
torch.tensor([False, True, False, True, True]), (torch.tensor([True, True, True]), [1, 2, 3]),
torch.tensor([0, 1, 1, 2, 3]),
),
(torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
], ],
) )
def test_placeholder_range_embeds_cumsum(is_embed, expected): def test_placeholder_range_embeds_cumsum(is_embed, expected):
...@@ -41,6 +38,6 @@ def test_placeholder_range_embeds_cumsum(is_embed, expected): ...@@ -41,6 +38,6 @@ def test_placeholder_range_embeds_cumsum(is_embed, expected):
assert pr.embeds_cumsum is None assert pr.embeds_cumsum is None
return return
assert torch.equal(pr.embeds_cumsum, expected) assert pr.embeds_cumsum == expected
# cached_property should return the same object on repeated access # cached_property should return the same object on repeated access
assert pr.embeds_cumsum is pr.embeds_cumsum assert pr.embeds_cumsum is pr.embeds_cumsum
...@@ -145,14 +145,15 @@ class PlaceholderRange: ...@@ -145,14 +145,15 @@ class PlaceholderRange:
""" """
@cached_property @cached_property
def embeds_cumsum(self) -> torch.Tensor | None: def embeds_cumsum(self) -> list[int] | None:
return None if self.is_embed is None else self.is_embed.cumsum(dim=0) # python list so python indexing avoids torch C++ overhead/conversions/deallocs
return None if self.is_embed is None else self.is_embed.cumsum(dim=0).tolist()
def get_num_embeds(self) -> int: def get_num_embeds(self) -> int:
if self.embeds_cumsum is None: if self.embeds_cumsum is None:
return self.length return self.length
return int(self.embeds_cumsum[-1]) return self.embeds_cumsum[-1] if self.embeds_cumsum else 0
def get_embeds_indices_in_range( def get_embeds_indices_in_range(
self, start_idx: int, end_idx: int self, start_idx: int, end_idx: int
...@@ -170,10 +171,8 @@ class PlaceholderRange: ...@@ -170,10 +171,8 @@ class PlaceholderRange:
if self.embeds_cumsum is None: if self.embeds_cumsum is None:
return start_idx, end_idx return start_idx, end_idx
embeds_start_idx = ( embeds_start_idx = self.embeds_cumsum[start_idx - 1] if start_idx > 0 else 0
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0 embeds_end_idx = self.embeds_cumsum[end_idx - 1] if end_idx > 0 else 0
)
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
return embeds_start_idx, embeds_end_idx return embeds_start_idx, embeds_end_idx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment