Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c3270a92
Commit
c3270a92
authored
Apr 22, 2026
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.15.1-dev' into v0.15.1-dev
parents
feced2f1
0b7cc6cf
Changes
37
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
2290 additions
and
8 deletions
+2290
-8
vllm/model_executor/layers/quantization/turboquant/__init__.py
...model_executor/layers/quantization/turboquant/__init__.py
+14
-0
vllm/model_executor/layers/quantization/turboquant/centroids.py
...odel_executor/layers/quantization/turboquant/centroids.py
+86
-0
vllm/model_executor/layers/quantization/turboquant/config.py
vllm/model_executor/layers/quantization/turboquant/config.py
+186
-0
vllm/model_executor/layers/quantization/turboquant/quantizer.py
...odel_executor/layers/quantization/turboquant/quantizer.py
+40
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+5
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+6
-0
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+6
-0
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+4
-0
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+1
-1
vllm/v1/attention/backends/registry.py
vllm/v1/attention/backends/registry.py
+1
-0
vllm/v1/attention/backends/turboquant_attn.py
vllm/v1/attention/backends/turboquant_attn.py
+812
-0
vllm/v1/attention/ops/triton_turboquant_decode.py
vllm/v1/attention/ops/triton_turboquant_decode.py
+624
-0
vllm/v1/attention/ops/triton_turboquant_store.py
vllm/v1/attention/ops/triton_turboquant_store.py
+460
-0
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+7
-0
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+26
-0
vllm/v1/worker/dp_utils.py
vllm/v1/worker/dp_utils.py
+2
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+10
-6
No files found.
vllm/model_executor/layers/quantization/turboquant/__init__.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant: Near-optimal KV-cache quantization for vLLM.
PolarQuant compression: random rotation + per-coordinate Lloyd-Max
scalar quantization for keys, uniform quantization for values.
Reference: "TurboQuant: Online Vector Quantization with Near-optimal
Distortion Rate" (ICLR 2026), Zandieh et al.
"""
from
vllm.model_executor.layers.quantization.turboquant.config
import
TurboQuantConfig
__all__
=
[
"TurboQuantConfig"
]
vllm/model_executor/layers/quantization/turboquant/centroids.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Lloyd-Max optimal scalar quantizer for TurboQuant.
After rotating a d-dimensional unit vector by a random orthogonal matrix,
each coordinate approximately follows N(0, 1/d) for d >= 64.
We solve the Lloyd-Max conditions to find optimal centroids.
Based on: turboquant-pytorch/lloyd_max.py (Zandieh et al.)
"""
import
math
from
functools
import
lru_cache
import
torch
def _gaussian_pdf(x: float, sigma2: float) -> float:
    """Return the density of a zero-mean Gaussian with variance ``sigma2`` at ``x``."""
    normalizer = 1.0 / math.sqrt(2 * math.pi * sigma2)
    exponent = -x * x / (2 * sigma2)
    return normalizer * math.exp(exponent)
def _trapz(f, a: float, b: float, n: int = 200) -> float:
    """Integrate ``f`` over ``[a, b]`` with the composite trapezoidal rule.

    Used in place of ``scipy.integrate.quad``. The interval is divided into
    ``n`` equal panels; the two end points carry weight one half and the
    ``n - 1`` interior nodes carry weight one.
    """
    step = (b - a) / n
    total = 0.5 * (f(a) + f(b))
    k = 1
    while k < n:
        # Nodes are computed from ``a`` each time (not accumulated) so that
        # rounding error does not build up along the interval.
        total += f(a + k * step)
        k += 1
    return total * step
def solve_lloyd_max(
    d: int,
    bits: int,
    max_iter: int = 200,
    tol: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the Lloyd-Max iteration for a scalar N(0, 1/d) source.

    Args:
        d: Vector dimension; the source variance is ``1 / d``.
        bits: Quantizer bit-width; ``2**bits`` levels are produced.
        max_iter: Upper bound on the number of Lloyd-Max sweeps.
        tol: Iteration stops once no centroid moves by ``tol`` or more.

    Returns:
        A ``(centroids, boundaries)`` pair of float32 tensors: the
        ``2**bits`` reconstruction levels in ascending order and the
        ``2**bits - 1`` midpoints between neighbouring levels.
    """
    n_levels = 2**bits
    sigma2 = 1.0 / d
    sigma = math.sqrt(sigma2)

    def density(x):
        return _gaussian_pdf(x, sigma2)

    def first_moment(x):
        return x * density(x)

    def midpoints(levels):
        # Decision thresholds sit halfway between adjacent levels.
        return [(left + right) / 2.0 for left, right in zip(levels, levels[1:])]

    # Start from the centres of n_levels equal-width bins on [-3.5σ, 3.5σ].
    lo, hi = -3.5 * sigma, 3.5 * sigma
    centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]

    for _ in range(max_iter):
        # The outermost cells are integrated out to ±10.5σ (3x the initial
        # range) as a finite stand-in for the infinite tails.
        edges = [lo * 3, *midpoints(centroids), hi * 3]

        updated = []
        for previous, left, right in zip(centroids, edges, edges[1:]):
            num = _trapz(first_moment, left, right)
            den = _trapz(density, left, right)
            # A cell with (numerically) no probability mass keeps its level.
            updated.append(num / den if den > 1e-15 else previous)

        shift = max(abs(new - old) for new, old in zip(updated, centroids))
        if shift < tol:
            # Converged: the previous iterate is kept as the result.
            break
        centroids = updated

    boundaries = midpoints(centroids)
    return (
        torch.tensor(centroids, dtype=torch.float32),
        torch.tensor(boundaries, dtype=torch.float32),
    )
@lru_cache(maxsize=32)
def get_centroids(d: int, bits: int) -> torch.Tensor:
    """Return the Lloyd-Max centroids for ``(d, bits)``, memoised per argument pair.

    The decision boundaries computed alongside the centroids are discarded.
    Because of the cache, repeated calls with the same arguments hand back
    the same tensor object.
    """
    return solve_lloyd_max(d, bits)[0]
vllm/model_executor/layers/quantization/turboquant/config.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant configuration."""
import
math
from
dataclasses
import
dataclass
# Named TQ presets: each maps to frozen config parameters.
# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys.
# value_quant_bits: 3-4 = uniform quantized values.
# norm_correction: forwarded unchanged to TurboQuantConfig.norm_correction.
#
# The dictionary keys are the cache-dtype strings accepted by
# TurboQuantConfig.from_cache_dtype(); an unknown name raises ValueError there.
# NOTE(review): the same four names also appear in STR_DTYPE_TO_TORCH_DTYPE
# (vllm/utils/torch_utils.py) — presumably the two tables must be kept in sync.
TQ_PRESETS: dict[str, dict] = {
    # FP8 keys, 4-bit values, no norm correction.
    "turboquant_k8v4": {
        "key_quant_bits": 8,
        "value_quant_bits": 4,
        "norm_correction": False,
    },
    # 4-bit MSE keys, 4-bit values, norm correction on.
    "turboquant_4bit_nc": {
        "key_quant_bits": 4,
        "value_quant_bits": 4,
        "norm_correction": True,
    },
    # 3-bit MSE keys, 4-bit values, norm correction on.
    "turboquant_k3v4_nc": {
        "key_quant_bits": 3,
        "value_quant_bits": 4,
        "norm_correction": True,
    },
    # 3-bit MSE keys, 3-bit values, norm correction on.
    "turboquant_3bit_nc": {
        "key_quant_bits": 3,
        "value_quant_bits": 3,
        "norm_correction": True,
    },
}
@dataclass
class TurboQuantConfig:
    """Parameters describing how TurboQuant packs one KV-cache slot.

    Keys go through PolarQuant (WHT rotation followed by Lloyd-Max scalar
    quantization) unless FP8 keys are selected; values are uniformly
    quantized. QJL is deliberately left out — community consensus (5+
    independent groups) found that it hurts attention quality by amplifying
    variance through softmax.

    Named presets (selected with --kv-cache-dtype):
        turboquant_k8v4:    FP8 keys + 4-bit values, 2.6x, +1.17% PPL
        turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
        turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
        turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%

    Args:
        head_dim: Attention head dimension (e.g. 64, 96, 128).
        key_quant_bits: Key bit-width. 8 selects FP8 keys (no rotation/MSE);
            3-4 selects Lloyd-Max MSE quantized keys.
        value_quant_bits: Bits per value element for uniform quantization
            (3 = 8 levels, 4 = 16 levels, the default).
        seed: Base seed for deterministic random matrix generation; the
            per-layer seed is seed + layer_idx * 1337.
        norm_correction: Re-normalize centroid vectors to unit norm before the
            inverse rotation during dequant. Fixes quantization-induced norm
            distortion, improving PPL by ~0.8% at 4-bit.
    """

    head_dim: int = 128
    key_quant_bits: int = 3  # 3-4 = MSE keys, 8 = FP8 keys
    value_quant_bits: int = 4  # 3-4 = uniform quantized values
    seed: int = 42
    norm_correction: bool = False

    @property
    def key_fp8(self) -> bool:
        """True when keys are kept as FP8 and skip rotation/quantization."""
        return self.key_quant_bits == 8

    @property
    def mse_bits(self) -> int:
        """Bit-width of the MSE quantizer (there are 2^mse_bits centroids).

        Equals key_quant_bits for MSE keys. With FP8 keys it falls back to
        value_quant_bits, because centroids are still needed for
        continuation-prefill dequant and decode kernel params.
        """
        return self.value_quant_bits if self.key_fp8 else self.key_quant_bits

    @property
    def key_mse_bits(self) -> int:
        """Bits per key element spent on MSE indices; 0 with FP8 keys."""
        return 0 if self.key_fp8 else self.key_quant_bits

    @property
    def centroid_bits(self) -> int:
        """Bit-width used to generate centroids — never zero."""
        return self.mse_bits

    @property
    def n_centroids(self) -> int:
        """Number of Lloyd-Max centroids, i.e. 2^mse_bits."""
        return 2**self.mse_bits

    @property
    def key_packed_size(self) -> int:
        """Bytes occupied by one packed KEY vector.

        FP8 keys (key_quant_bits=8): head_dim bytes, one per element.
        MSE keys: ceil(head_dim * key_mse_bits / 8) index bytes, then
        vec_norm and res_norm as float16 (2 bytes each).
        """
        if not self.key_fp8:
            index_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8)
            return index_bytes + 4  # + two float16 norms
        return self.head_dim

    @property
    def effective_value_quant_bits(self) -> int:
        """Bit-width actually used when storing values."""
        return self.value_quant_bits

    @property
    def value_packed_size(self) -> int:
        """Bytes occupied by one packed VALUE vector.

        ceil(head_dim * value_quant_bits / 8) data bytes plus a float16 scale
        and a float16 zero point (4 bytes together).
        """
        payload = math.ceil(self.head_dim * self.value_quant_bits / 8)
        return payload + 4

    @property
    def slot_size(self) -> int:
        """Bytes per head per position, laid out as [key_packed | value_packed]."""
        return self.key_packed_size + self.value_packed_size

    @property
    def slot_size_aligned(self) -> int:
        """slot_size rounded up to the next even number.

        An even size keeps effective_head_size = slot_size_aligned // 2
        integral.
        """
        size = self.slot_size
        return size + size % 2

    @staticmethod
    def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
        """List the layers that should be left uncompressed (boundary protection).

        Returns the indices of the first ``n`` and last ``n`` layers as
        strings, in a form suitable for kv_cache_dtype_skip_layers. At most
        half of the layers are taken from each end.
        """
        if n <= 0 or num_layers <= 0:
            return []
        keep = min(n, num_layers // 2)
        indices = sorted({*range(keep), *range(num_layers - keep, num_layers)})
        return list(map(str, indices))

    @staticmethod
    def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
        """Build a config from a preset name found in TQ_PRESETS.

        Valid names include turboquant_k8v4, turboquant_4bit_nc, etc.

        Raises:
            ValueError: if ``cache_dtype`` is not a known preset.
        """
        preset = TQ_PRESETS.get(cache_dtype)
        if preset is None:
            valid = ", ".join(TQ_PRESETS)
            raise ValueError(
                f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
                f"Valid presets: {valid}"
            )
        overrides = {
            name: preset[name]
            for name in ("key_quant_bits", "value_quant_bits", "norm_correction")
        }
        return TurboQuantConfig(head_dim=head_dim, **overrides)
vllm/model_executor/layers/quantization/turboquant/quantizer.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TurboQuant quantizer utilities.
Serving path uses generate_wht_signs() for WHT rotation sign buffers.
generate_rotation_matrix() is retained for standalone benchmarks only.
Triton kernels handle all quantization, packing, and dequantization on GPU.
"""
import
torch
def generate_rotation_matrix(
    d: int, seed: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
    """Build a seeded ``d x d`` random orthogonal matrix from a QR factorisation.

    A float32 Gaussian matrix is drawn on the CPU from a generator seeded
    with ``seed`` and factorised; the orthogonal factor is returned on
    ``device``.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed)
    gaussian = torch.randn(d, d, generator=rng, device="cpu", dtype=torch.float32)
    Q, R = torch.linalg.qr(gaussian)

    # QR is unique only up to column signs; scale each column of Q by the
    # sign of the matching diagonal entry of R (zeros count as +1) so the
    # result is deterministic.
    signs = torch.sign(R.diagonal())
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    return (Q * signs[None, :]).to(device)
def generate_wht_signs(
    d: int, seed: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
    """Draw a reproducible length-``d`` vector of ±1 signs for WHT rotation.

    The signs randomise the Walsh-Hadamard Transform per layer. The seed is
    derived the same way as for the QR rotation (seed + layer_idx * stride).
    Sampling happens on the CPU; the float result is then moved to ``device``.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed)
    coin_flips = torch.randint(0, 2, (d,), generator=rng, device="cpu")
    # Map {0, 1} -> {-1.0, +1.0}.
    return (coin_flips.float() * 2 - 1).to(device)
vllm/platforms/cuda.py
View file @
c3270a92
...
...
@@ -280,6 +280,11 @@ class CudaPlatformBase(Platform):
valid_backends_priorities
=
[]
invalid_reasons
=
{}
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"turboquant_"
):
return
[(
AttentionBackendEnum
.
TURBOQUANT
,
0
)],
{}
backend_priorities
=
_get_backend_priorities
(
attn_selector_config
.
use_mla
,
device_capability
)
...
...
vllm/platforms/rocm.py
View file @
c3270a92
...
...
@@ -264,6 +264,12 @@ class RocmPlatform(Platform):
block_size
=
attn_selector_config
.
block_size
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"turboquant_"
):
logger
.
info_once
(
"Using TurboQuant attention backend."
)
return
AttentionBackendEnum
.
TURBOQUANT
.
get_path
()
if
attn_selector_config
.
use_sparse
:
# if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
# raise ValueError(
...
...
vllm/platforms/xpu.py
View file @
c3270a92
...
...
@@ -52,6 +52,12 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
# TurboQuant KV cache: route directly to TQ backend
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"turboquant_"
):
logger
.
info_once
(
"Using TurboQuant attention backend."
)
return
AttentionBackendEnum
.
TURBOQUANT
.
get_path
()
dtype
=
attn_selector_config
.
dtype
if
attn_selector_config
.
use_sparse
:
raise
NotImplementedError
(
"Sparse Attention is not supported on XPU."
)
...
...
vllm/utils/torch_utils.py
View file @
c3270a92
...
...
@@ -42,6 +42,10 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"int8"
:
torch
.
int8
,
"fp8_inc"
:
torch
.
float8_e4m3fn
,
"fp8_ds_mla"
:
torch
.
uint8
,
"turboquant_k8v4"
:
torch
.
uint8
,
"turboquant_4bit_nc"
:
torch
.
uint8
,
"turboquant_k3v4_nc"
:
torch
.
uint8
,
"turboquant_3bit_nc"
:
torch
.
uint8
,
}
TORCH_DTYPE_TO_NUMPY_DTYPE
=
{
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
c3270a92
...
...
@@ -202,7 +202,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class
DeepseekV32IndexerMetadataBuilder
(
AttentionMetadataBuilder
):
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
AttentionCGSupport
.
UNIFORM_
SINGLE_TOKEN_DECODE
AttentionCGSupport
.
UNIFORM_
BATCH
)
reorder_batch_threshold
:
int
=
1
...
...
vllm/v1/attention/backends/registry.py
View file @
c3270a92
...
...
@@ -78,6 +78,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN
=
"vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
TURBOQUANT
=
"vllm.v1.attention.backends.turboquant_attn.TurboQuantAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
# set to None to avoid alias with other backend, whose value is an empty string
CUSTOM
=
None
...
...
vllm/v1/attention/backends/turboquant_attn.py
0 → 100644
View file @
c3270a92
This diff is collapsed.
Click to expand it.
vllm/v1/attention/ops/triton_turboquant_decode.py
0 → 100644
View file @
c3270a92
This diff is collapsed.
Click to expand it.
vllm/v1/attention/ops/triton_turboquant_store.py
0 → 100644
View file @
c3270a92
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused Triton kernels for TurboQuant KV store.
Two kernels:
1. _tq_fused_store_fp8: FP8 key scatter + value uniform quantization.
2. _tq_fused_store_mse: Fused bucketize + centroid gather + residual norm
+ MSE index packing + value quantization (eliminates 4 PyTorch kernel
launches vs the old pack-only approach).
The launcher `triton_turboquant_store` selects the appropriate kernel.
"""
import
math
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.ops.triton_turboquant_decode
import
_use_fp8_e4b15
# ═══════════════════════════════════════════════════════════════════════
# Shared: value uniform quantization + pack + scale/zero store
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _store_quantized_value(
    Value_ptr,
    KV_cache_ptr,
    base,  # pid * D offset into Value_ptr
    slot_base,  # byte offset into KV_cache_ptr for this slot+head
    d_offs,  # tl.arange(0, BLOCK_D)
    d_mask,  # d_offs < D
    D: tl.constexpr,
    KPS: tl.constexpr,
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    BLOCK_VAL: tl.constexpr,
    BLOCK_GRP: tl.constexpr,
):
    """Uniform quantization of values to VQB bits, pack, and store with scale/zero.

    The value vector at ``Value_ptr + base`` is min/max quantized over its
    ``D`` valid elements and written at byte offset ``slot_base + KPS`` of the
    cache as:

        [ VAL_DATA_BYTES packed codes | scale fp16 (2 B) | zero/min fp16 (2 B) ]

    Both fp16 fields are written one byte at a time, low byte first.

    ``VQB == 3`` packs eight 3-bit codes into three bytes; any other ``VQB``
    takes the 4-bit path, which packs two codes per byte.
    """
    # Value section starts right after the packed key (KPS bytes).
    val_cache_offset = KPS

    if VQB == 3:
        val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
            tl.float32
        )
        # Padding lanes are replaced by +/-inf so they cannot win the reduction.
        val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
        val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
        # 3 bits -> codes 0..7, hence 7 steps across [min, max].
        v_scale = (val_max - val_min) / 7.0
        # Floor the step so a constant vector does not divide by zero.
        v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)

        # Round to nearest (+0.5 then truncating int cast), clamp to [0, 7].
        q_vals = tl.minimum(
            tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 7
        )

        # Groups of 8 codes -> 24 bits -> 3 bytes. Only the D // 8 complete
        # groups are written.
        # NOTE(review): the reshape assumes BLOCK_GRP * 8 equals the length of
        # d_offs — confirm against the launcher.
        grp_offs = tl.arange(0, BLOCK_GRP)
        grp_mask = grp_offs < (D // 8)
        q_grp = tl.reshape(q_vals, [BLOCK_GRP, 8])
        # Code j of a group occupies bits [3j, 3j + 3).
        shifts_3bit = tl.arange(0, 8) * 3
        packed_24 = tl.sum(q_grp << shifts_3bit[None, :], axis=1)
        b0 = (packed_24 & 0xFF).to(tl.uint8)
        b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
        b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
        tl.store(
            KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3,
            b0,
            mask=grp_mask,
        )
        tl.store(
            KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 1,
            b1,
            mask=grp_mask,
        )
        tl.store(
            KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 2,
            b2,
            mask=grp_mask,
        )

        # Scale, then zero point (= val_min), as raw fp16 bit patterns.
        sc_offset = val_cache_offset + VAL_DATA_BYTES
        sc_f16 = v_scale.to(tl.float16)
        sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
        tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
        tl.store(
            KV_cache_ptr + slot_base + sc_offset + 1,
            ((sc_u16 >> 8) & 0xFF).to(tl.uint8),
        )
        zr_f16 = val_min.to(tl.float16)
        zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
        tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
        tl.store(
            KV_cache_ptr + slot_base + sc_offset + 3,
            ((zr_u16 >> 8) & 0xFF).to(tl.uint8),
        )

    else:  # VQB == 4
        val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to(
            tl.float32
        )
        val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
        val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
        # 4 bits -> codes 0..15, hence 15 steps across [min, max].
        v_scale = (val_max - val_min) / 15.0
        v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8)

        # One lane per output byte; each lane re-reads its even/odd element
        # pair. Out-of-range elements read as val_min, i.e. code 0.
        val_offs = tl.arange(0, BLOCK_VAL)
        val_mask = val_offs < VAL_DATA_BYTES
        v0 = tl.load(
            Value_ptr + base + val_offs * 2,
            mask=val_mask & (val_offs * 2 < D),
            other=val_min,
        )
        v1 = tl.load(
            Value_ptr + base + val_offs * 2 + 1,
            mask=val_mask & (val_offs * 2 + 1 < D),
            other=val_min,
        )
        q0 = tl.minimum(
            tl.maximum(((v0 - val_min) / v_scale + 0.5).to(tl.int32), 0), 15
        )
        q1 = tl.minimum(
            tl.maximum(((v1 - val_min) / v_scale + 0.5).to(tl.int32), 0), 15
        )
        # Even element in the low nibble, odd element in the high nibble.
        packed_val = (q0 | (q1 << 4)).to(tl.uint8)
        tl.store(
            KV_cache_ptr + slot_base + val_cache_offset + val_offs,
            packed_val,
            mask=val_mask,
        )

        # Scale, then zero point (= val_min), as raw fp16 bit patterns.
        sc_offset = val_cache_offset + VAL_DATA_BYTES
        sc_f16 = v_scale.to(tl.float16)
        sc_u16 = sc_f16.to(tl.uint16, bitcast=True)
        tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
        tl.store(
            KV_cache_ptr + slot_base + sc_offset + 1,
            ((sc_u16 >> 8) & 0xFF).to(tl.uint8),
        )
        zr_f16 = val_min.to(tl.float16)
        zr_u16 = zr_f16.to(tl.uint16, bitcast=True)
        tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8))
        tl.store(
            KV_cache_ptr + slot_base + sc_offset + 3,
            ((zr_u16 >> 8) & 0xFF).to(tl.uint8),
        )
# ═══════════════════════════════════════════════════════════════════════
# FP8 key store + value uniform quantization
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_fp8(
    Key_ptr,  # [NH, D] float16/bfloat16 — raw keys
    Value_ptr,  # [NH, D] float16/bfloat16 — raw values
    KV_cache_ptr,  # [total_bytes] uint8 (flattened view)
    Slot_mapping_ptr,  # [N] int32 or int64 — per-token slot indices
    # Cache strides (for computing byte offsets)
    stride_cache_block: tl.constexpr,
    stride_cache_pos: tl.constexpr,
    stride_cache_head: tl.constexpr,
    # Dimensions
    D: tl.constexpr,
    H: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_D: tl.constexpr,
    # TQ layout
    KPS: tl.constexpr,
    # Value quantization
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    # Packing block sizes
    BLOCK_VAL: tl.constexpr,
    BLOCK_GRP: tl.constexpr = 16,
    FP8_E4B15: tl.constexpr = 0,  # 1 = e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
):
    """FP8 key cast+scatter + value uniform quantization: one program per (token, head).

    Program ``pid`` handles token ``pid // H`` and head ``pid % H``. The key
    row is cast to FP8 and its raw bytes are written at the start of the
    slot; the value row is then quantized and written by
    ``_store_quantized_value`` at byte offset ``KPS`` within the same slot.
    Tokens whose slot index is negative are skipped.
    """
    pid = tl.program_id(0)
    token_idx = pid // H
    head_idx = pid % H

    # Widen to int64 before any multiplication: with int32 slot indices the
    # byte offset blk * stride_cache_block would be computed in 32 bits and
    # wrap around once the cache reaches 2 GiB.
    slot = tl.load(Slot_mapping_ptr + token_idx).to(tl.int64)
    if slot < 0:
        return
    blk = slot // BLOCK_SIZE
    off = slot % BLOCK_SIZE
    slot_base = (
        blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
    )

    # Element offset of this (token, head) row in the dense [NH, D] inputs.
    base = pid * D

    # ── FP8 KEY: cast to FP8 in-kernel and store ─────────────────
    d_offs = tl.arange(0, BLOCK_D)
    d_mask = d_offs < D
    k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0)
    if FP8_E4B15:
        k_fp8 = k_vals.to(tl.float8e4b15)
    else:
        # The e4nv cast goes through float32 first.
        x_f32 = k_vals.to(tl.float32)
        k_fp8 = x_f32.to(tl.float8e4nv)
    # Reinterpret the FP8 values as bytes; the cache is a uint8 buffer.
    k_bytes = k_fp8.to(tl.uint8, bitcast=True)
    tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask)

    # ── VALUE QUANTIZE + PACK ───────────────────────────────────────
    _store_quantized_value(
        Value_ptr,
        KV_cache_ptr,
        base,
        slot_base,
        d_offs,
        d_mask,
        D=D,
        KPS=KPS,
        VQB=VQB,
        VAL_DATA_BYTES=VAL_DATA_BYTES,
        BLOCK_VAL=BLOCK_VAL,
        BLOCK_GRP=BLOCK_GRP,
    )
# ═══════════════════════════════════════════════════════════════════════
# Fused MSE store: bucketize + centroid gather + residual norm + pack
# (eliminates 4 PyTorch kernel launches per layer vs pack-only kernel)
# ═══════════════════════════════════════════════════════════════════════
@triton.jit
def _tq_fused_store_mse(
    # Post-rotation inputs
    Y_ptr,  # [NH, D] float32 — rotated normalized keys (x_hat @ PiT)
    Norms_ptr,  # [NH] float32 — key vector norms (||k||)
    Value_ptr,  # [NH, D] float32 — raw values
    # Quantization tables
    Centroids_ptr,  # [n_centroids] float32
    Midpoints_ptr,  # [n_centroids-1] float32
    # Cache and indexing
    KV_cache_ptr,  # [total_bytes] uint8 (flattened view)
    Slot_mapping_ptr,  # [N] int32 or int64 — per-token slot indices
    # Cache strides
    stride_cache_block: tl.constexpr,
    stride_cache_pos: tl.constexpr,
    stride_cache_head: tl.constexpr,
    # Dimensions
    D: tl.constexpr,
    H: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_D: tl.constexpr,
    # TQ layout
    MSE_BYTES: tl.constexpr,
    KPS: tl.constexpr,
    # Value quantization
    VQB: tl.constexpr,
    VAL_DATA_BYTES: tl.constexpr,
    # Packing block sizes
    BLOCK_VAL: tl.constexpr,
    # MSE params
    MSE_BITS: tl.constexpr,
    N_CENTROIDS: tl.constexpr,
    BLOCK_GRP: tl.constexpr = 16,
):
    """Fused MSE quantize + pack + store.

    Performs bucketize, centroid gather, residual norm, MSE index packing,
    and value quantization in one kernel — eliminates 4 PyTorch kernel
    launches (bucketize, gather, subtract, norm) per layer vs pack-only.

    One program handles one (token, head) pair (``pid // H``, ``pid % H``).
    Bytes written to the slot, in order:

        [ MSE_BYTES packed indices | vec_norm fp16 | residual norm fp16 ]

    followed, at byte offset ``KPS``, by the value section produced by
    ``_store_quantized_value``. Tokens whose slot index is negative are
    skipped. Only ``MSE_BITS`` of 3 or 4 is supported; any other value fails
    at compile time.
    """
    pid = tl.program_id(0)
    token_idx = pid // H
    head_idx = pid % H

    # Widen to int64 before any multiplication: with int32 slot indices the
    # byte offset blk * stride_cache_block would be computed in 32 bits and
    # wrap around once the cache reaches 2 GiB.
    slot = tl.load(Slot_mapping_ptr + token_idx).to(tl.int64)
    if slot < 0:
        return
    blk = slot // BLOCK_SIZE
    off = slot % BLOCK_SIZE
    slot_base = (
        blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
    )

    # Element offset of this (token, head) row in the dense [NH, D] inputs.
    base = pid * D
    d_offs = tl.arange(0, BLOCK_D)
    d_mask = d_offs < D

    # ── 1. INLINE BUCKETIZE ──────────────────────────────────────────
    # The bucket index of each coordinate is the number of midpoints it is
    # greater than or equal to.
    y_vec = tl.load(Y_ptr + base + d_offs, mask=d_mask, other=0.0)
    idx = tl.zeros([BLOCK_D], dtype=tl.int32)
    for i in range(N_CENTROIDS - 1):
        mid_val = tl.load(Midpoints_ptr + i)
        idx += tl.where(y_vec >= mid_val, 1, 0)

    # ── 2. CENTROID GATHER + RESIDUAL NORM ────────────────────────────
    centroid_vals = tl.load(Centroids_ptr + idx, mask=d_mask, other=0.0)
    residual = y_vec - centroid_vals
    # L2 norm of the quantization residual over the D valid lanes only.
    gamma = tl.sqrt(tl.sum(tl.where(d_mask, residual * residual, 0.0), axis=0))

    # ── 3. PACK MSE INDICES from register idx ─────────────────────────
    if MSE_BITS == 4:
        # Two indices per byte: even element in the low nibble.
        idx_pairs = tl.reshape(idx, [BLOCK_D // 2, 2])
        shifts_4 = tl.arange(0, 2) * 4
        packed = tl.sum((idx_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)
        mse_offs = tl.arange(0, BLOCK_D // 2)
        mse_mask = mse_offs < MSE_BYTES
        tl.store(KV_cache_ptr + slot_base + mse_offs, packed, mask=mse_mask)
    elif MSE_BITS == 3:
        # Eight indices per 24-bit group, written as three bytes; only the
        # D // 8 complete groups are stored.
        grp_offs = tl.arange(0, BLOCK_GRP)
        grp_mask = grp_offs < (D // 8)
        idx_grp = tl.reshape(idx, [BLOCK_GRP, 8])
        shifts_3 = tl.arange(0, 8) * 3
        packed_24 = tl.sum((idx_grp & 0x7) << shifts_3[None, :], axis=1)
        b0 = (packed_24 & 0xFF).to(tl.uint8)
        b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
        b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
        tl.store(KV_cache_ptr + slot_base + grp_offs * 3, b0, mask=grp_mask)
        tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 1, b1, mask=grp_mask)
        tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 2, b2, mask=grp_mask)
    else:
        # Without this branch an unsupported width would silently leave the
        # index bytes of the slot unwritten.
        tl.static_assert(False, "_tq_fused_store_mse supports MSE_BITS of 3 or 4 only")

    # ── 4. STORE NORMS (vec_norm + gamma as fp16) ─────────────────────
    # Both norms go right after the index bytes as raw fp16 bit patterns,
    # low byte first.
    norm_offset = MSE_BYTES
    vn_f16 = tl.load(Norms_ptr + pid).to(tl.float16)
    vn_u16 = vn_f16.to(tl.uint16, bitcast=True)
    tl.store(KV_cache_ptr + slot_base + norm_offset, (vn_u16 & 0xFF).to(tl.uint8))
    tl.store(
        KV_cache_ptr + slot_base + norm_offset + 1, ((vn_u16 >> 8) & 0xFF).to(tl.uint8)
    )
    gm_f16 = gamma.to(tl.float16)
    gm_u16 = gm_f16.to(tl.uint16, bitcast=True)
    tl.store(KV_cache_ptr + slot_base + norm_offset + 2, (gm_u16 & 0xFF).to(tl.uint8))
    tl.store(
        KV_cache_ptr + slot_base + norm_offset + 3, ((gm_u16 >> 8) & 0xFF).to(tl.uint8)
    )

    # ── 5. VALUE QUANTIZE + PACK ──────────────────────────────────────
    _store_quantized_value(
        Value_ptr,
        KV_cache_ptr,
        base,
        slot_base,
        d_offs,
        d_mask,
        D=D,
        KPS=KPS,
        VQB=VQB,
        VAL_DATA_BYTES=VAL_DATA_BYTES,
        BLOCK_VAL=BLOCK_VAL,
        BLOCK_GRP=BLOCK_GRP,
    )
# ═══════════════════════════════════════════════════════════════════════
# Launcher
# ═══════════════════════════════════════════════════════════════════════
def triton_turboquant_store(
    key: torch.Tensor,  # [N, H, D] — raw keys (post-RoPE)
    value: torch.Tensor,  # [N, H, D] — raw values
    kv_cache: torch.Tensor,  # [num_blocks, block_size, Hk, padded_slot] uint8
    slot_mapping: torch.Tensor,  # [N] int32
    PiT: torch.Tensor,  # [D, D] float32
    centroids: torch.Tensor,  # [n_centroids] float32
    midpoints: torch.Tensor,  # [n_centroids-1] float32
    mse_bits: int,
    key_packed_size: int,
    value_quant_bits: int,
    key_fp8: bool = False,
):
    """Launch TQ store kernel — FP8 uses _tq_fused_store_fp8, MSE uses _tq_fused_store_mse.

    One Triton program is launched per (token, head) pair. The cache is
    handed to the kernels as a flat byte buffer and addressed with strides
    derived from ``kv_cache.shape``, so ``kv_cache`` must be viewable as 1-D.

    Args:
        key: Keys to store; ``key.shape`` supplies ``N``, ``H`` and ``D``.
        value: Values to store, reshaped to ``[N * H, D]`` like ``key``.
        kv_cache: Destination cache; dimensions 1-3 are read as block size,
            number of KV heads and padded slot size in bytes.
        slot_mapping: Per-token slot index; negative entries are skipped by
            the kernels.
        PiT: Rotation matrix applied to the normalized keys (MSE path only).
        centroids: Lloyd-Max centroid table (MSE path only).
        midpoints: Decision boundaries between centroids (MSE path only).
        mse_bits: Bits per key element for MSE indices.
        key_packed_size: Byte offset of the value section inside a slot.
        value_quant_bits: Bits per value element.
        key_fp8: Store keys as FP8 instead of MSE indices.

    Raises:
        ValueError: on the MSE path when ``mse_bits`` is not 3 or 4, the only
            widths the fused kernel packs.
    """
    if not key_fp8 and mse_bits not in (3, 4):
        # The MSE kernel only has packing code for 3 and 4 bits; any other
        # width would leave the key indices of every slot unwritten.
        raise ValueError(
            f"TurboQuant MSE key store supports mse_bits of 3 or 4, got {mse_bits!r}"
        )

    N, H, D = key.shape
    NH = N * H
    if NH == 0:
        # Nothing to store; skip the launch instead of using an empty grid.
        return

    block_size = kv_cache.shape[1]
    num_kv_heads = kv_cache.shape[2]
    padded_slot = kv_cache.shape[3]

    BLOCK_D = triton.next_power_of_2(D)
    mse_bytes = math.ceil(D * mse_bits / 8)
    n_centroids = 2**mse_bits

    val_data_bytes = math.ceil(D * value_quant_bits / 8)
    BLOCK_VAL = triton.next_power_of_2(val_data_bytes)

    # Cache strides (in bytes, for a densely packed uint8 cache)
    stride_block = block_size * num_kv_heads * padded_slot
    stride_pos = num_kv_heads * padded_slot
    stride_head = padded_slot

    # Number of 8-element groups used by the 3-bit packers, as a power of two.
    block_grp = triton.next_power_of_2(D // 8) if D >= 8 else 1

    # ── FP8 PATH: in-kernel FP8 cast + scatter via fp8 kernel ──
    if key_fp8:
        k_flat = key.reshape(NH, D).contiguous()
        v_flat = value.reshape(NH, D).contiguous()
        fp8_e4b15 = _use_fp8_e4b15(key.device.index or 0)
        grid = (NH,)
        _tq_fused_store_fp8[grid](
            k_flat,
            v_flat,
            kv_cache.view(-1),
            slot_mapping,
            stride_cache_block=stride_block,
            stride_cache_pos=stride_pos,
            stride_cache_head=stride_head,
            D=D,
            H=H,
            BLOCK_SIZE=block_size,
            BLOCK_D=BLOCK_D,
            KPS=key_packed_size,
            VQB=value_quant_bits,
            VAL_DATA_BYTES=val_data_bytes,
            BLOCK_VAL=BLOCK_VAL,
            BLOCK_GRP=block_grp,
            FP8_E4B15=fp8_e4b15,
            num_warps=4,
            num_stages=1,
        )
        return

    # ── MSE PATH: external GEMM + fused bucketize/pack kernel ──
    # Normalize + rotation GEMM externally (cuBLAS is faster than in-kernel)
    k_flat = key.float().reshape(NH, D)
    norms = k_flat.norm(dim=1, keepdim=True)
    # Epsilon keeps all-zero key rows from dividing by zero.
    x_hat = k_flat / (norms + 1e-8)
    y = (x_hat @ PiT).contiguous()
    # The kernel addresses rows as pid * D, so the values must be dense.
    # reshape() may return a strided view (e.g. for a float32 input sliced
    # along the last dimension); contiguous() is a no-op otherwise.
    v_flat = value.float().reshape(NH, D).contiguous()

    # Fused kernel: bucketize + centroid gather + residual norm + pack
    grid = (NH,)
    _tq_fused_store_mse[grid](
        y,
        norms.squeeze(1),
        v_flat,
        centroids,
        midpoints,
        kv_cache.view(-1),
        slot_mapping,
        stride_cache_block=stride_block,
        stride_cache_pos=stride_pos,
        stride_cache_head=stride_head,
        D=D,
        H=H,
        BLOCK_SIZE=block_size,
        BLOCK_D=BLOCK_D,
        MSE_BYTES=mse_bytes,
        KPS=key_packed_size,
        VQB=value_quant_bits,
        VAL_DATA_BYTES=val_data_bytes,
        BLOCK_VAL=BLOCK_VAL,
        MSE_BITS=mse_bits,
        N_CENTROIDS=n_centroids,
        BLOCK_GRP=block_grp,
        num_warps=4,
        num_stages=1,
    )
vllm/v1/core/single_type_kv_cache_manager.py
View file @
c3270a92
...
...
@@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (
MLAAttentionSpec
,
SinkFullAttentionSpec
,
SlidingWindowSpec
,
TQFullAttentionSpec
,
)
from
vllm.v1.request
import
Request
...
...
@@ -51,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_pool
=
block_pool
self
.
enable_caching
=
enable_caching
# self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
...
...
@@ -204,6 +206,8 @@ class SingleTypeKVCacheManager(ABC):
cdiv
(
num_total_computed_tokens
,
self
.
block_size
)
-
len
(
req_blocks
)
)
req_blocks
.
extend
(
allocated_blocks
)
# if isinstance(self.kv_cache_spec, FullAttentionSpec):
# self.new_block_ids.extend(b.block_id for b in allocated_blocks)
def
allocate_new_blocks
(
self
,
request_id
:
str
,
num_tokens
:
int
,
num_tokens_main_model
:
int
...
...
@@ -230,6 +234,8 @@ class SingleTypeKVCacheManager(ABC):
else
:
new_blocks
=
self
.
block_pool
.
get_new_blocks
(
num_new_blocks
)
req_blocks
.
extend
(
new_blocks
)
# if isinstance(self.kv_cache_spec, FullAttentionSpec):
# self.new_block_ids.extend(b.block_id for b in new_blocks)
return
new_blocks
def
cache_blocks
(
self
,
request
:
Request
,
num_tokens
:
int
)
->
None
:
...
...
@@ -1048,6 +1054,7 @@ class SinkFullAttentionManager(FullAttentionManager):
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
FullAttentionSpec
:
FullAttentionManager
,
TQFullAttentionSpec
:
FullAttentionManager
,
MLAAttentionSpec
:
FullAttentionManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
ChunkedLocalAttentionSpec
:
ChunkedLocalAttentionManager
,
...
...
vllm/v1/kv_cache_interface.py
View file @
c3270a92
...
...
@@ -187,6 +187,32 @@ class FullAttentionSpec(AttentionSpec):
)
@dataclass(frozen=True, kw_only=True)
class TQFullAttentionSpec(FullAttentionSpec):
    """Full-attention KV cache spec whose page size follows the TQ slot size.

    Python equivalent of the C++ TQ4FullAttentionSpec. When ``tq_slot_size``
    is positive, the page size is derived from it rather than from the
    inherited head_size * dtype formula.
    """

    # Bytes per head per position; 0 (or less) means "use the inherited size".
    tq_slot_size: int = 0

    @property
    def real_page_size_bytes(self) -> int:
        """Bytes per page: block_size * num_kv_heads * tq_slot_size when set."""
        if self.tq_slot_size <= 0:
            return super().real_page_size_bytes
        return self.block_size * self.num_kv_heads * self.tq_slot_size

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge ``specs`` via the parent class, carrying over the shared slot size."""
        merged = super().merge(specs)
        slot_size = specs[0].tq_slot_size
        assert all(spec.tq_slot_size == slot_size for spec in specs), (
            "All TQ layers in the same KV cache group must use the same tq_slot_size."
        )
        return replace(merged, tq_slot_size=slot_size)
@
dataclass
(
frozen
=
True
,
kw_only
=
True
)
class
MLAAttentionSpec
(
FullAttentionSpec
):
# TODO(Lucas/Chen): less hacky way to do this
...
...
vllm/v1/worker/dp_utils.py
View file @
c3270a92
...
...
@@ -6,6 +6,7 @@ import numpy as np
import
torch
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.logger
import
init_logger
...
...
@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
]
"""
if
parallel_config
.
data_parallel_size
==
1
:
if
parallel_config
.
data_parallel_size
==
1
or
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
:
# Early exit.
return
False
,
None
,
cudagraph_mode
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
c3270a92
...
...
@@ -189,6 +189,7 @@ from .utils import (
sanity_check_mm_encoder_outputs
,
)
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.utils.torch_utils
import
async_tensor_h2d
if
TYPE_CHECKING
:
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
...
...
@@ -5117,9 +5118,6 @@ class GPUModelRunner(
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
model_kwargs
=
self
.
_init_model_kwargs
()
else
:
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
=
torch
.
randint
(
0
,
self
.
model_config
.
get_vocab_size
(),
(
num_tokens_padded
,),
dtype
=
torch
.
int32
)
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
inputs_embeds
=
None
...
...
@@ -5234,9 +5232,15 @@ class GPUModelRunner(
self
.
eplb_step
(
is_dummy
=
True
,
is_profile
=
is_profile
)
logit_indices
=
np
.
cumsum
(
num_scheduled_tokens
)
-
1
logit_indices_device
=
torch
.
from_numpy
(
logit_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
# logit_indices_device = torch.from_numpy(logit_indices).to(
# self.device, non_blocking=True
# )
logit_indices
=
logit_indices
.
tolist
()
logit_indices_device
=
async_tensor_h2d
(
logit_indices
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
return
hidden_states
,
hidden_states
[
logit_indices_device
]
@
torch
.
inference_mode
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment