Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
626ccb7d
"apps/life_sci/python/dgllife/model/gnn/weave.py" did not exist on "28117cd96da17ffa4e5b08dd3e3cc04ba323a245"
Unverified
Commit
626ccb7d
authored
May 19, 2025
by
Mick
Committed by
GitHub
May 18, 2025
Browse files
vlm: tensor hash kernel (#5974)
parent
72bfb0ba
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
1 deletion
+73
-1
python/sglang/srt/layers/multimodal.py
python/sglang/srt/layers/multimodal.py
+70
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+3
-1
No files found.
python/sglang/srt/layers/multimodal.py
0 → 100644
View file @
626ccb7d
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logits processing."""
import
torch
import
triton
import
triton.language
as
tl
@triton.jit
def hash_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    PRIME: tl.constexpr,
    XCONST: tl.constexpr,
):
    """Hash each int32 word of ``input_ptr`` independently into ``output_ptr``.

    Every lane mixes one input word with its global offset (so equal values at
    different positions hash differently), then applies two multiply /
    xor-shift rounds driven by the ``PRIME`` / ``XCONST`` constants.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Masked load: lanes past the end read 0; their results are never stored.
    word = tl.load(input_ptr + idx, mask=in_bounds, other=0)

    # Position-dependent mix followed by two avalanche rounds.
    h = (word ^ (idx + XCONST)) * PRIME
    h = h ^ (h >> 16)
    h = h * (PRIME ^ XCONST)
    h = h ^ (h >> 13)

    tl.store(output_ptr + idx, h, mask=in_bounds)
# xxHash64 prime constants, expressed as signed two's-complement 64-bit values
# (p - 2**64) so they fit an int64 when handed to the Triton kernel as
# constexpr arguments.
PRIME_1 = 11400714785074694791 - (1 << 64)
PRIME_2 = 14029467366897019727 - (1 << 64)
def gpu_tensor_hash(tensor: torch.Tensor) -> int:
    """Compute a deterministic integer hash of a CUDA tensor's raw bytes.

    The tensor's storage is reinterpreted as a flat sequence of int32 words,
    each word is hashed independently on the GPU by ``hash_kernel``, and the
    per-word hashes are combined on the host with a sum.

    Args:
        tensor: A tensor on a CUDA device, of any shape and dtype.

    Returns:
        The combined hash as a Python ``int``.

    Raises:
        AssertionError: If ``tensor`` is not on a CUDA device.
    """
    assert tensor.is_cuda

    # Flatten before reinterpreting: ``view(dtype)`` only constrains the last
    # dimension, so hashing a flat view accepts any shape whose total byte
    # count is 4-aligned (a multi-dim ``view(torch.int32)`` would reject e.g.
    # a (2, 3) bfloat16 tensor). Memory order is unchanged, so hashes for
    # previously-accepted inputs are identical.
    tensor = tensor.contiguous().flatten()
    nbytes = tensor.numel() * tensor.element_size()
    if nbytes % 4 == 0:
        tensor = tensor.view(torch.int32)
    else:
        # Total byte count is not a multiple of 4 (e.g. odd-length fp16):
        # zero-pad the byte view up to the next int32 boundary.
        byte_view = tensor.view(torch.uint8)
        padding = byte_view.new_zeros((-nbytes) % 4)
        tensor = torch.cat([byte_view, padding]).view(torch.int32)

    n = tensor.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n, BLOCK_SIZE),)

    intermediate_hashes = torch.empty(n, dtype=torch.int32, device=tensor.device)

    hash_kernel[grid](
        tensor,
        intermediate_hashes,
        n,
        BLOCK_SIZE=BLOCK_SIZE,
        PRIME=PRIME_1,
        XCONST=PRIME_2,
    )

    # TODO: threads can't be synced on triton kernel, so the per-word hashes
    # are reduced here on the host; sum is commutative, making the result
    # independent of block execution order.
    final_hash = intermediate_hashes.sum().item()

    return final_hash
python/sglang/srt/managers/schedule_batch.py
View file @
626ccb7d
...
...
@@ -49,6 +49,7 @@ from sglang.srt.configs.model_config import ModelConfig
from
sglang.srt.constrained.base_grammar_backend
import
BaseGrammarObject
from
sglang.srt.disaggregation.base
import
BaseKVSender
from
sglang.srt.disaggregation.decode
import
ScheduleBatchDisaggregationDecodeMixin
from
sglang.srt.layers.multimodal
import
gpu_tensor_hash
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPoolAllocator
...
...
@@ -222,7 +223,8 @@ class MultimodalDataItem:
for
x
in
tensor_list
]
tensor
=
torch
.
concat
(
tensor_list
)
if
tensor
.
is_cuda
:
return
gpu_tensor_hash
(
tensor
)
tensor
=
tensor
.
detach
().
contiguous
()
if
tensor
.
dtype
==
torch
.
bfloat16
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment