change / sglang / Commits / 0b1e04f0

[VLM] Improving multimodal tensor hash kernel (#9008)

Unverified commit 0b1e04f0, authored Aug 15, 2025 by Adarsh Shirawalmath, committed by GitHub Aug 14, 2025. Parent: c1c7dc45.

Showing 1 changed file with 156 additions and 40 deletions:

python/sglang/srt/layers/multimodal.py  (+156, -40)
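(Reading note, not part of the commit page.) In outline, the diff below replaces the previous single-pass multiply/xor hash kernel (hash_kernel) and its host driver with a tiled two-lane 32-bit mixing kernel (hash_tiles32_kernel_blocked), a uint64 tree-reduction kernel (add_tree_reduce_u64_kernel), a byte-padding helper, and a host-side splitmix64 finalizer, and reworks gpu_tensor_hash to drive them.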
@@ -17,57 +17,173 @@ import torch
 import triton
 import triton.language as tl
 
+FMIX32_C1 = 0x85EBCA6B
+FMIX32_C2 = 0xC2B2AE35
+POS_C1 = 0x27D4EB2D
+POS_C2 = 0x165667B1
+
+
+@triton.jit
+def _rotl32(x, r: tl.constexpr):
+    return (x << r) | (x >> (32 - r))
+
+
+@triton.jit
+def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
+    c1 = tl.full((), C1, tl.uint32)
+    c2 = tl.full((), C2, tl.uint32)
+    x ^= x >> 16
+    x = x * c1
+    x ^= x >> 13
+    x = x * c2
+    x ^= x >> 16
+    return x
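(Reference note, not part of the diff.) The shift/multiply pattern above (shifts 16/13/16 with constants 0x85EBCA6B and 0xC2B2AE35) matches the MurmurHash3 32-bit finalizer. A minimal pure-Python equivalent, useful for checking the Triton helpers on the host; function names here are illustrative and do not appear in the commit:

def rotl32(x: int, r: int) -> int:
    # 32-bit rotate left, emulating uint32 wraparound with a mask
    x &= 0xFFFFFFFF
    return ((x << r) | (x >> (32 - r))) & 0xFFFFFFFF

def fmix32(x: int, c1: int = 0x85EBCA6B, c2: int = 0xC2B2AE35) -> int:
    # MurmurHash3-style 32-bit finalizer (same shifts/constants as _fmix32)
    x &= 0xFFFFFFFF
    x ^= x >> 16
    x = (x * c1) & 0xFFFFFFFF
    x ^= x >> 13
    x = (x * c2) & 0xFFFFFFFF
    x ^= x >> 16
    return x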
 @triton.jit
-def hash_kernel(
-    input_ptr,
-    output_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    PRIME: tl.constexpr,
-    XCONST: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
-
-    data = tl.load(input_ptr + offsets, mask=mask, other=0).to(tl.int64)
-    mixed = data ^ (offsets.to(tl.int64) + XCONST)
-    hash_val = mixed * PRIME
-    hash_val = hash_val ^ (hash_val >> 16)
-    hash_val = hash_val * (PRIME ^ XCONST)
-    hash_val = hash_val ^ (hash_val >> 13)
-
-    tl.store(output_ptr + offsets, hash_val, mask=mask)
+def hash_tiles32_kernel_blocked(
+    in_ptr,
+    out_ptr,
+    n_u32,
+    seed1,
+    seed2,
+    FM_C1: tl.constexpr,
+    FM_C2: tl.constexpr,
+    POS_A: tl.constexpr,
+    POS_B: tl.constexpr,
+    TILE: tl.constexpr,
+    BLOCK: tl.constexpr,
+    USE_CG: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    base = pid * TILE
+
+    s1 = tl.full((), seed1, tl.uint32)
+    s2 = tl.full((), seed2, tl.uint32)
+    posA = tl.full((), POS_A, tl.uint32)
+    posB = tl.full((), POS_B, tl.uint32)
+
+    h1 = tl.zeros((), dtype=tl.uint32)
+    h2 = tl.zeros((), dtype=tl.uint32)
+
+    for off in tl.static_range(0, TILE, BLOCK):
+        idx = base + off + tl.arange(0, BLOCK)
+        m = idx < n_u32
+
+        if USE_CG:
+            v = tl.load(in_ptr + idx, mask=m, other=0, cache_modifier=".cg")
+        else:
+            v = tl.load(in_ptr + idx, mask=m, other=0)
+        v = v.to(tl.uint32)
+
+        iu = idx.to(tl.uint32)
+        p1 = (iu * posA + s1) ^ _rotl32(iu, 15)
+        p2 = (iu * posB + s2) ^ _rotl32(iu, 13)
+
+        k1 = _fmix32(v ^ p1, C1=FM_C1, C2=FM_C2)
+        k2 = _fmix32(v ^ p2, C1=FM_C1, C2=FM_C2)
+
+        zero32 = tl.zeros_like(k1)
+        k1 = tl.where(m, k1, zero32)
+        k2 = tl.where(m, k2, zero32)
+
+        h1 += tl.sum(k1, axis=0).to(tl.uint32)
+        h2 += tl.sum(k2, axis=0).to(tl.uint32)
+
+    nbytes = tl.full((), n_u32 * 4, tl.uint32)
+    h1 ^= nbytes
+    h2 ^= nbytes
+
+    h1 = _fmix32(h1, C1=FM_C1, C2=FM_C2)
+    h2 = (
+        _fmix32(h2, C1=FMIX32_C1, C2=FMIX32_C2)
+        if False
+        else _fmix32(h2, C1=FM_C1, C2=FM_C2)
+    )
+
+    out = (h1.to(tl.uint64) << 32) | h2.to(tl.uint64)
+    tl.store(out_ptr + pid, out)
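(Reference note, not part of the diff.) Each program instance above hashes one TILE-sized window of uint32 words into two independent 32-bit lanes, folds in the total input length, and packs the lanes into a uint64 partial. A rough pure-Python sketch of that per-tile computation, reusing the host-side rotl32/fmix32 sketches above; the function and parameter names are illustrative only:

def hash_tile_reference(words, n_u32, seed1, seed2,
                        pos_a=0x27D4EB2D, pos_b=0x165667B1):
    # words: iterable of (absolute_index, uint32_value) pairs for one tile
    M = 0xFFFFFFFF
    h1 = h2 = 0
    for idx, v in words:
        iu = idx & M
        # position-dependent keys, one per lane
        p1 = ((iu * pos_a + seed1) & M) ^ rotl32(iu, 15)
        p2 = ((iu * pos_b + seed2) & M) ^ rotl32(iu, 13)
        # mix each word and accumulate with uint32 wraparound
        h1 = (h1 + fmix32((v ^ p1) & M)) & M
        h2 = (h2 + fmix32((v ^ p2) & M)) & M
    nbytes = (n_u32 * 4) & M        # total input length in bytes
    h1 = fmix32(h1 ^ nbytes)
    h2 = fmix32(h2 ^ nbytes)
    return (h1 << 32) | h2          # 64-bit partial, as stored per program id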
+
+
+@triton.jit
+def add_tree_reduce_u64_kernel(in_ptr, out_ptr, n_elems, CHUNK: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    start = pid * CHUNK
+    h = tl.zeros((), dtype=tl.uint64)
+    for i in tl.static_range(0, CHUNK):
+        idx = start + i
+        m = idx < n_elems
+        v = tl.load(in_ptr + idx, mask=m, other=0).to(tl.uint64)
+        h += v
+    tl.store(out_ptr + pid, h)
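(Reference note, not part of the diff.) The per-tile uint64 partials are combined by plain wrap-around addition, which is commutative and associative, so reducing them in chunks across multiple kernel launches gives the same result as a sequential sum. On the host the equivalent combine is simply:

total = 0
for p in partials_list:                     # hypothetical list of per-tile uint64 partials
    total = (total + p) & ((1 << 64) - 1)   # uint64 wraparound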
+
+
+def _as_uint32_words(t: torch.Tensor) -> torch.Tensor:
+    assert t.is_cuda, "Use .cuda() first"
+    tb = t.contiguous().view(torch.uint8)
+    nbytes = tb.numel()
+    pad = (4 - (nbytes & 3)) & 3
+    if pad:
+        tb_p = torch.empty(nbytes + pad, dtype=torch.uint8, device=tb.device)
+        tb_p[:nbytes].copy_(tb)
+        tb_p[nbytes:].zero_()
+        tb = tb_p
+    return tb.view(torch.uint32)
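(Reference note, not part of the diff.) _as_uint32_words reinterprets the tensor as raw bytes and zero-pads up to a multiple of 4 so the buffer can be viewed as uint32 words. A small hypothetical check, assuming the module is importable as sglang.srt.layers.multimodal:

x = torch.arange(6, dtype=torch.uint8, device="cuda")   # 6 bytes of payload
w = _as_uint32_words(x)                                  # padded to 8 bytes
assert w.numel() == 2 and w.dtype == torch.uint32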
-PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
-PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
+
+
+def _final_splitmix64(x: int) -> int:
+    mask = (1 << 64) - 1
+    x &= mask
+    x ^= x >> 30
+    x = (x * 0xBF58476D1CE4E5B9) & mask
+    x ^= x >> 27
+    x = (x * 0x94D049BB133111EB) & mask
+    x ^= x >> 31
+    return x
 
 
-def gpu_tensor_hash(tensor: torch.Tensor) -> int:
-    assert tensor.is_cuda
-    tensor = tensor.contiguous().view(torch.int32)
-    n = tensor.numel()
-    BLOCK_SIZE = 1024
-    grid = (triton.cdiv(n, BLOCK_SIZE),)
-
-    intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
-
-    # Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-    # Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
-    with torch.cuda.device(tensor.device):
-        hash_kernel[grid](
-            tensor,
-            intermediate_hashes,
-            n,
-            BLOCK_SIZE=BLOCK_SIZE,
-            PRIME=PRIME_1,
-            XCONST=PRIME_2,
-        )
-
-    # TODO: threads can't be synced on triton kernel
-    final_hash = intermediate_hashes.sum().item()
-
-    return final_hash
+@torch.inference_mode()
+def gpu_tensor_hash(
+    tensor: torch.Tensor,
+    *,
+    seed: int = 0x243F6A88,
+    tile_words: int = 8192,
+    block_words: int = 256,
+    reduce_chunk: int = 1024,
+    num_warps: int = 4,
+    num_stages: int = 4,
+    use_cg: bool = True,
+) -> int:
+    assert tensor.is_cuda, "Use .cuda() first"
+    u32 = _as_uint32_words(tensor)
+    n = u32.numel()
+    if n == 0:
+        return 0
+
+    grid1 = (triton.cdiv(n, tile_words),)
+    partials = torch.empty(grid1[0], dtype=torch.uint64, device=u32.device)
+    hash_tiles32_kernel_blocked[grid1](
+        u32,
+        partials,
+        n,
+        seed1=seed & 0xFFFFFFFF,
+        seed2=((seed * 0x9E3779B1) ^ 0xDEADBEEF) & 0xFFFFFFFF,
+        FM_C1=FMIX32_C1,
+        FM_C2=FMIX32_C2,
+        POS_A=POS_C1,
+        POS_B=POS_C2,
+        TILE=tile_words,
+        BLOCK=block_words,
+        USE_CG=use_cg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    cur = partials
+    while cur.numel() > 1:
+        n_elems = cur.numel()
+        grid2 = (triton.cdiv(n_elems, reduce_chunk),)
+        nxt = torch.empty(grid2[0], dtype=torch.uint64, device=cur.device)
+        add_tree_reduce_u64_kernel[grid2](cur, nxt, n_elems, CHUNK=reduce_chunk)
+        cur = nxt
+
+    return _final_splitmix64(int(cur.item()))
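(Reference note, not part of the diff.) A minimal usage sketch of the new entry point, assuming the package is installed so the module is importable; the tensor shape and dtype are illustrative only:

import torch
from sglang.srt.layers.multimodal import gpu_tensor_hash

x = torch.randn(3, 224, 224, dtype=torch.float16, device="cuda")
h = gpu_tensor_hash(x)                    # 64-bit digest of the tensor's bytes
assert h == gpu_tensor_hash(x.clone())    # identical bytes -> identical hash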