Unverified Commit 73b35cca authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[Core] Improve hash collision avoidance in prefix caching (#12621)


Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
parent 5095e966
...@@ -65,8 +65,8 @@ class TestPrefixCachingBlock: ...@@ -65,8 +65,8 @@ class TestPrefixCachingBlock:
previous_block = MagicMock(spec=PrefixCachingBlock) previous_block = MagicMock(spec=PrefixCachingBlock)
prev_block_hash = random.randint(0, 1000) prev_block_hash = random.randint(0, 1000)
previous_block.content_hash = (prev_block_hash previous_block.content_hash = (prev_block_hash if prev_block_has_hash
if prev_block_has_hash else None) else hash('None'))
num_to_fill = block_size if is_curr_block_full else random.randint( num_to_fill = block_size if is_curr_block_full else random.randint(
0, block_size - 1) 0, block_size - 1)
......
...@@ -65,6 +65,15 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -65,6 +65,15 @@ class PrefixCachingBlockAllocator(BlockAllocator):
from 0 to num_blocks - 1. from 0 to num_blocks - 1.
""" """
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
# Implements Block.Factory.
def __init__( def __init__(
self, self,
num_blocks: int, num_blocks: int,
...@@ -122,7 +131,6 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -122,7 +131,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self.metric_data = CacheMetricData() self.metric_data = CacheMetricData()
# Implements Block.Factory.
def _create_block( def _create_block(
self, self,
prev_block: Optional[Block], prev_block: Optional[Block],
...@@ -737,6 +745,14 @@ class PrefixCachingBlock(Block): ...@@ -737,6 +745,14 @@ class PrefixCachingBlock(Block):
such as adapters that influence the block, apart from the token_ids. such as adapters that influence the block, apart from the token_ids.
""" """
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
def __init__( def __init__(
self, self,
prev_block: Optional[Block], prev_block: Optional[Block],
...@@ -891,13 +907,13 @@ class PrefixCachingBlock(Block): ...@@ -891,13 +907,13 @@ class PrefixCachingBlock(Block):
is_first_block = self._prev_block is None is_first_block = self._prev_block is None
prev_block_hash = ( prev_block_hash = (
None if is_first_block else self._none_hash if is_first_block else
self._prev_block.content_hash # type: ignore self._prev_block.content_hash # type: ignore
) )
# Previous block exists but does not yet have a hash. # Previous block exists but does not yet have a hash.
# Return no hash in this case. # Return no hash in this case.
if prev_block_hash is None and not is_first_block: if prev_block_hash == self._none_hash and not is_first_block:
return None return None
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
...@@ -907,8 +923,9 @@ class PrefixCachingBlock(Block): ...@@ -907,8 +923,9 @@ class PrefixCachingBlock(Block):
extra_hash=self._extra_hash) extra_hash=self._extra_hash)
return self._cached_content_hash return self._cached_content_hash
@staticmethod @classmethod
def hash_block_tokens(is_first_block: bool, def hash_block_tokens(cls,
is_first_block: bool,
prev_block_hash: Optional[int], prev_block_hash: Optional[int],
cur_block_token_ids: List[int], cur_block_token_ids: List[int],
extra_hash: Optional[int] = None) -> int: extra_hash: Optional[int] = None) -> int:
...@@ -929,7 +946,8 @@ class PrefixCachingBlock(Block): ...@@ -929,7 +946,8 @@ class PrefixCachingBlock(Block):
Returns: Returns:
- int: The computed hash value for the block. - int: The computed hash value for the block.
""" """
assert (prev_block_hash is None) == is_first_block if is_first_block and prev_block_hash is None:
prev_block_hash = cls._none_hash
return hash((is_first_block, prev_block_hash, *cur_block_token_ids, return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
extra_hash)) extra_hash))
...@@ -949,6 +967,14 @@ class ComputedBlocksTracker: ...@@ -949,6 +967,14 @@ class ComputedBlocksTracker:
cached block hashes in the allocator. cached block hashes in the allocator.
""" """
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
def __init__( def __init__(
self, self,
allocator: DeviceAwareBlockAllocator, allocator: DeviceAwareBlockAllocator,
...@@ -994,7 +1020,7 @@ class ComputedBlocksTracker: ...@@ -994,7 +1020,7 @@ class ComputedBlocksTracker:
# We need to know the hash of the previous block to compute the hash of # We need to know the hash of the previous block to compute the hash of
# the current block so that blocks could be uniquely identified across # the current block so that blocks could be uniquely identified across
# sequences of prefixes. # sequences of prefixes.
prev_block_hash = (None if cur_num_blocks_recorded == 0 else prev_block_hash = (self._none_hash if cur_num_blocks_recorded == 0 else
block_hashes_recorded[-1]) block_hashes_recorded[-1])
# Only update the computed block hashes for the new blocks # Only update the computed block hashes for the new blocks
for i in range(cur_num_blocks_recorded, num_computed_blocks): for i in range(cur_num_blocks_recorded, num_computed_blocks):
...@@ -1009,7 +1035,7 @@ class ComputedBlocksTracker: ...@@ -1009,7 +1035,7 @@ class ComputedBlocksTracker:
# This has to be kept in sync with the allocator's hash # This has to be kept in sync with the allocator's hash
# calculation. # calculation.
block_hash = PrefixCachingBlock.hash_block_tokens( block_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block=prev_block_hash is None, is_first_block=prev_block_hash == self._none_hash,
prev_block_hash=prev_block_hash, prev_block_hash=prev_block_hash,
cur_block_token_ids=block_token_ids, cur_block_token_ids=block_token_ids,
extra_hash=extra_hash, extra_hash=extra_hash,
......
...@@ -263,6 +263,15 @@ def hash_block_tokens( ...@@ -263,6 +263,15 @@ def hash_block_tokens(
The hash value of the block and the token ids in the block. The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block. The entire tuple is used as the hash key of the block.
""" """
if not parent_block_hash:
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
parent_block_hash = hash('None')
curr_block_token_ids_tuple = tuple(curr_block_token_ids) curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHashType( return BlockHashType(
hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)), hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment