Unverified Commit 0b1e04f0 authored by Adarsh Shirawalmath's avatar Adarsh Shirawalmath Committed by GitHub
Browse files

[VLM] Improving multimodal tensor hash kernel (#9008)

parent c1c7dc45
......@@ -17,57 +17,173 @@ import torch
import triton
import triton.language as tl
FMIX32_C1 = 0x85EBCA6B
FMIX32_C2 = 0xC2B2AE35
POS_C1 = 0x27D4EB2D
POS_C2 = 0x165667B1
@triton.jit
def _rotl32(x, r: tl.constexpr):
return (x << r) | (x >> (32 - r))
@triton.jit
def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
c1 = tl.full((), C1, tl.uint32)
c2 = tl.full((), C2, tl.uint32)
x ^= x >> 16
x = x * c1
x ^= x >> 13
x = x * c2
x ^= x >> 16
return x
@triton.jit
def hash_kernel(
input_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
PRIME: tl.constexpr,
XCONST: tl.constexpr,
def hash_tiles32_kernel_blocked(
in_ptr,
out_ptr,
n_u32,
seed1,
seed2,
FM_C1: tl.constexpr,
FM_C2: tl.constexpr,
POS_A: tl.constexpr,
POS_B: tl.constexpr,
TILE: tl.constexpr,
BLOCK: tl.constexpr,
USE_CG: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
base = pid * TILE
s1 = tl.full((), seed1, tl.uint32)
s2 = tl.full((), seed2, tl.uint32)
posA = tl.full((), POS_A, tl.uint32)
posB = tl.full((), POS_B, tl.uint32)
h1 = tl.zeros((), dtype=tl.uint32)
h2 = tl.zeros((), dtype=tl.uint32)
for off in tl.static_range(0, TILE, BLOCK):
idx = base + off + tl.arange(0, BLOCK)
m = idx < n_u32
data = tl.load(input_ptr + offsets, mask=mask, other=0).to(tl.int64)
mixed = data ^ (offsets.to(tl.int64) + XCONST)
hash_val = mixed * PRIME
hash_val = hash_val ^ (hash_val >> 16)
hash_val = hash_val * (PRIME ^ XCONST)
hash_val = hash_val ^ (hash_val >> 13)
if USE_CG:
v = tl.load(in_ptr + idx, mask=m, other=0, cache_modifier=".cg")
else:
v = tl.load(in_ptr + idx, mask=m, other=0)
v = v.to(tl.uint32)
iu = idx.to(tl.uint32)
p1 = (iu * posA + s1) ^ _rotl32(iu, 15)
p2 = (iu * posB + s2) ^ _rotl32(iu, 13)
k1 = _fmix32(v ^ p1, C1=FM_C1, C2=FM_C2)
k2 = _fmix32(v ^ p2, C1=FM_C1, C2=FM_C2)
zero32 = tl.zeros_like(k1)
k1 = tl.where(m, k1, zero32)
k2 = tl.where(m, k2, zero32)
h1 += tl.sum(k1, axis=0).to(tl.uint32)
h2 += tl.sum(k2, axis=0).to(tl.uint32)
nbytes = tl.full((), n_u32 * 4, tl.uint32)
h1 ^= nbytes
h2 ^= nbytes
h1 = _fmix32(h1, C1=FM_C1, C2=FM_C2)
h2 = (
_fmix32(h2, C1=FMIX32_C1, C2=FMIX32_C2)
if False
else _fmix32(h2, C1=FM_C1, C2=FM_C2)
)
out = (h1.to(tl.uint64) << 32) | h2.to(tl.uint64)
tl.store(out_ptr + pid, out)
@triton.jit
def add_tree_reduce_u64_kernel(in_ptr, out_ptr, n_elems, CHUNK: tl.constexpr):
pid = tl.program_id(axis=0)
start = pid * CHUNK
h = tl.zeros((), dtype=tl.uint64)
for i in tl.static_range(0, CHUNK):
idx = start + i
m = idx < n_elems
v = tl.load(in_ptr + idx, mask=m, other=0).to(tl.uint64)
h += v
tl.store(out_ptr + pid, h)
tl.store(output_ptr + offsets, hash_val, mask=mask)
def _as_uint32_words(t: torch.Tensor) -> torch.Tensor:
assert t.is_cuda, "Use .cuda() first"
tb = t.contiguous().view(torch.uint8)
nbytes = tb.numel()
pad = (4 - (nbytes & 3)) & 3
if pad:
tb_p = torch.empty(nbytes + pad, dtype=torch.uint8, device=tb.device)
tb_p[:nbytes].copy_(tb)
tb_p[nbytes:].zero_()
tb = tb_p
return tb.view(torch.uint32)
PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
def _final_splitmix64(x: int) -> int:
mask = (1 << 64) - 1
x &= mask
x ^= x >> 30
x = (x * 0xBF58476D1CE4E5B9) & mask
x ^= x >> 27
x = (x * 0x94D049BB133111EB) & mask
x ^= x >> 31
return x
def gpu_tensor_hash(tensor: torch.Tensor) -> int:
assert tensor.is_cuda
tensor = tensor.contiguous().view(torch.int32)
n = tensor.numel()
BLOCK_SIZE = 1024
grid = (triton.cdiv(n, BLOCK_SIZE),)
intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
@torch.inference_mode()
def gpu_tensor_hash(
tensor: torch.Tensor,
*,
seed: int = 0x243F6A88,
tile_words: int = 8192,
block_words: int = 256,
reduce_chunk: int = 1024,
num_warps: int = 4,
num_stages: int = 4,
use_cg: bool = True,
) -> int:
assert tensor.is_cuda, "Use .cuda() first"
u32 = _as_uint32_words(tensor)
n = u32.numel()
if n == 0:
return 0
# Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
# Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
with torch.cuda.device(tensor.device):
hash_kernel[grid](
tensor,
intermediate_hashes,
n,
BLOCK_SIZE=BLOCK_SIZE,
PRIME=PRIME_1,
XCONST=PRIME_2,
)
grid1 = (triton.cdiv(n, tile_words),)
partials = torch.empty(grid1[0], dtype=torch.uint64, device=u32.device)
hash_tiles32_kernel_blocked[grid1](
u32,
partials,
n,
seed1=seed & 0xFFFFFFFF,
seed2=((seed * 0x9E3779B1) ^ 0xDEADBEEF) & 0xFFFFFFFF,
FM_C1=FMIX32_C1,
FM_C2=FMIX32_C2,
POS_A=POS_C1,
POS_B=POS_C2,
TILE=tile_words,
BLOCK=block_words,
USE_CG=use_cg,
num_warps=num_warps,
num_stages=num_stages,
)
# TODO: threads can't be synced on triton kernel
final_hash = intermediate_hashes.sum().item()
cur = partials
while cur.numel() > 1:
n_elems = cur.numel()
grid2 = (triton.cdiv(n_elems, reduce_chunk),)
nxt = torch.empty(grid2[0], dtype=torch.uint64, device=cur.device)
add_tree_reduce_u64_kernel[grid2](cur, nxt, n_elems, CHUNK=reduce_chunk)
cur = nxt
return final_hash
return _final_splitmix64(int(cur.item()))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment