Unverified Commit 626ccb7d authored by Mick's avatar Mick Committed by GitHub
Browse files

vlm: tensor hash kernel (#5974)

parent 72bfb0ba
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logits processing."""
import torch
import triton
import triton.language as tl
@triton.jit
def hash_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    PRIME: tl.constexpr,
    XCONST: tl.constexpr,
):
    """Element-wise mixing hash: output[i] = mix(input[i], i).

    Each program instance hashes one BLOCK_SIZE chunk of the input. The
    element index is folded into the hash (via ``offsets + XCONST``) so that
    permuting the input changes the result even though the final reduction
    on the host side is an order-insensitive sum.

    Args (runtime):
        input_ptr:  pointer to the input values (int32 on the host side).
        output_ptr: pointer to the per-element hash output buffer.
        n_elements: total number of elements; tail lanes are masked off.
    Constexpr:
        BLOCK_SIZE: elements handled per program instance.
        PRIME, XCONST: 64-bit mixing constants (see PRIME_1 / PRIME_2).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Masked-off lanes load 0; their stores are also masked, so the padding
    # value never reaches the output buffer.
    data = tl.load(input_ptr + offsets, mask=mask, other=0)
    # Position-dependent pre-mix, then a multiply/xor-shift finalizer in the
    # style of xxHash/Murmur avalanche steps.
    # NOTE(review): PRIME/XCONST are 64-bit constants, so the arithmetic
    # presumably promotes to int64 here; the dtype of output_ptr decides
    # which bits are kept on store — confirm against the host-side buffer.
    mixed = data ^ (offsets + XCONST)
    hash_val = mixed * PRIME
    hash_val = hash_val ^ (hash_val >> 16)
    hash_val = hash_val * (PRIME ^ XCONST)
    hash_val = hash_val ^ (hash_val >> 13)
    tl.store(output_ptr + offsets, hash_val, mask=mask)
# xxHash64's PRIME64_1 / PRIME64_2, reinterpreted as *signed* 64-bit
# two's-complement values (u - 2**64) so they fit Triton's int64 constexpr
# range. Equivalent to the bit-flip form -(u ^ 0xFFFFFFFFFFFFFFFF) - 1.
PRIME_1 = 11400714785074694791 - (1 << 64)
PRIME_2 = 14029467366897019727 - (1 << 64)
def gpu_tensor_hash(tensor: torch.Tensor) -> int:
    """Hash a CUDA tensor's raw bytes into a single Python int.

    The tensor is reinterpreted as int32 words, each word is mixed on-GPU by
    ``hash_kernel`` (position-aware, so permutations hash differently), and
    the per-element hashes are reduced with a sum.

    Args:
        tensor: any CUDA tensor; hashed by content (raw bytes), not by dtype
            or shape metadata.

    Returns:
        An integer hash; 0 for an empty tensor.
    """
    assert tensor.is_cuda
    # Reinterpret the raw bytes instead of calling view(torch.int32)
    # directly: view() raises when the byte count is not a multiple of 4
    # (e.g. a bfloat16 tensor with an odd number of elements). Zero-pad the
    # tail so every input is hashable; inputs that were already 4-byte
    # aligned produce byte-identical int32 words either way.
    byte_view = tensor.contiguous().view(torch.uint8).view(-1)
    remainder = byte_view.numel() % 4
    if remainder:
        padding = torch.zeros(
            4 - remainder, dtype=torch.uint8, device=byte_view.device
        )
        byte_view = torch.cat([byte_view, padding])
    words = byte_view.view(torch.int32)

    n = words.numel()
    if n == 0:
        # Nothing to launch; an empty reduction would be 0 anyway.
        return 0

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n, BLOCK_SIZE),)

    intermediate_hashes = torch.empty(n, dtype=torch.int32, device=words.device)
    hash_kernel[grid](
        words,
        intermediate_hashes,
        n,
        BLOCK_SIZE=BLOCK_SIZE,
        PRIME=PRIME_1,
        XCONST=PRIME_2,
    )
    # TODO: threads can't be synced on triton kernel; reduce on the host
    # side instead. The sum is order-insensitive, but each element's hash
    # already encodes its position, so permutations still differ.
    final_hash = intermediate_hashes.sum().item()
    return final_hash
......@@ -49,6 +49,7 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
......@@ -222,7 +223,8 @@ class MultimodalDataItem:
for x in tensor_list
]
tensor = torch.concat(tensor_list)
if tensor.is_cuda:
return gpu_tensor_hash(tensor)
tensor = tensor.detach().contiguous()
if tensor.dtype == torch.bfloat16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment