Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc48ba0c
Unverified
Commit
dc48ba0c
authored
Sep 26, 2025
by
Bram Wasti
Committed by
GitHub
Sep 26, 2025
Browse files
Kernel-override Determinism [1/n] (#25603)
Signed-off-by:
Bram Wasti
<
bwasti@meta.com
>
parent
4778b426
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
890 additions
and
4 deletions
+890
-4
csrc/core/batch_invariant.hpp
csrc/core/batch_invariant.hpp
+16
-0
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+6
-2
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+4
-1
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+3
-1
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+290
-0
vllm/model_executor/layers/batch_invariant.py
vllm/model_executor/layers/batch_invariant.py
+561
-0
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+7
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-0
No files found.
csrc/core/batch_invariant.hpp
0 → 100644
View file @
dc48ba0c
#pragma once
#include <cstdlib>
#include <string>
#include <cctype>
namespace vllm {

/// Reports whether batch-invariant kernel overrides were requested.
///
/// Reads the environment variable VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT on
/// every call and parses it with std::atoi.
///
/// @return true when the variable is set and parses to a non-zero integer
///         (e.g. "1"); false when it is unset or parses to 0 (which
///         includes text with no leading digits, since atoi yields 0).
inline bool vllm_kernel_override_batch_invariant() {
  // The key is passed as a literal: no std::string needs to be built just to
  // call getenv.
  const char* val = std::getenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT");
  return val != nullptr && std::atoi(val) != 0;
}

}  // namespace vllm
csrc/layernorm_kernels.cu
View file @
dc48ba0c
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -413,7 +414,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr
%
req_alignment_bytes
==
0
;
bool
offsets_are_multiple_of_vector_width
=
hidden_size
%
vector_width
==
0
&&
input_stride
%
vector_width
==
0
;
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
)
{
bool
batch_invariant_launch
=
vllm
::
vllm_kernel_override_batch_invariant
();
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
...
...
@@ -459,7 +462,8 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
out_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
out
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
out_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
bool
batch_invariant_launch
=
vllm
::
vllm_kernel_override_batch_invariant
();
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_POLY_NORM
(
8
);
}
else
{
LAUNCH_FUSED_POLY_NORM
(
0
);
...
...
csrc/layernorm_quant_kernels.cu
View file @
dc48ba0c
...
...
@@ -9,6 +9,7 @@
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -240,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant(
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
input_stride
%
8
==
0
)
{
bool
batch_invariant_launch
=
vllm
::
vllm_kernel_override_batch_invariant
();
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
&&
input_stride
%
8
==
0
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
dc48ba0c
...
...
@@ -21,6 +21,7 @@
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#include "../core/batch_invariant.hpp"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
...
...
@@ -405,7 +406,8 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
using
Constants
=
detail
::
TopkConstants
<
EXPERTS
,
BYTES_PER_LDG
,
WARP_SIZE_PARAM
>
;
static
constexpr
int
VPT
=
Constants
::
VPT
;
static
constexpr
int
ROWS_PER_WARP
=
Constants
::
ROWS_PER_WARP
;
const
int
num_warps
=
(
num_rows
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
;
const
bool
batch_invariant_launch
=
vllm
::
vllm_kernel_override_batch_invariant
();
const
int
num_warps
=
batch_invariant_launch
?
32
:
(
num_rows
+
ROWS_PER_WARP
-
1
)
/
ROWS_PER_WARP
;
const
int
num_blocks
=
(
num_warps
+
WARPS_PER_TB
-
1
)
/
WARPS_PER_TB
;
dim3
block_dim
(
WARP_SIZE_PARAM
,
WARPS_PER_TB
);
...
...
tests/v1/generation/test_batch_invariance.py
0 → 100644
View file @
dc48ba0c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
os
import
random
import
string
import
pytest
import
torch
from
vllm
import
LLM
,
SamplingParams
def
_random_prompt
(
min_words
:
int
=
1024
,
max_words
:
int
=
1024
*
2
)
->
str
:
# Lightweight random prompt generator to vary prompt lengths and content.
vocab
=
[
"alpha"
,
"bravo"
,
"charlie"
,
"delta"
,
"echo"
,
"foxtrot"
,
"golf"
,
"hotel"
,
"india"
,
"juliet"
,
"kilo"
,
"lima"
,
"mike"
,
"november"
,
"oscar"
,
"papa"
,
"quebec"
,
"romeo"
,
"sierra"
,
"tango"
,
"uniform"
,
"victor"
,
"whiskey"
,
"xray"
,
"yankee"
,
"zulu"
,
]
n
=
random
.
randint
(
min_words
,
max_words
)
words
=
random
.
choices
(
vocab
,
k
=
n
)
# Add some noise and punctuation variability
if
random
.
random
()
<
0.5
:
words
[
0
]
=
words
[
0
].
capitalize
()
if
random
.
random
()
<
0.2
:
words
.
append
(
""
.
join
(
random
.
choices
(
string
.
ascii_lowercase
,
k
=
5
)))
punct
=
random
.
choice
([
"."
,
"?"
,
"!"
,
"..."
,
""
])
return
" "
.
join
(
words
)
+
punct
@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
    """
    Check that one fixed "needle" prompt produces the same text when it is
    generated alone and when it is generated inside a larger batch.

    Procedure:
    - Build an engine limited to max_num_seqs=1 and record the needle's
      output as the baseline.
    - Build a second engine limited to max_num_seqs=batch_size.
    - For each trial, submit batch_size prompts (the needle at a random
      index, random filler prompts elsewhere) and compare the needle's text
      against the baseline.
    - Print pass/fail totals, then fail the test if any trial mismatched.

    The model, trial count, batch size, memory limits and sampling settings
    come from environment variables with the defaults shown in the code.
    Sampling always uses a fixed seed; the default temperature is 0.0.
    """
    random.seed(12345)

    # Every knob can be overridden from the environment (useful for CI).
    # "facebook/opt-125m" is too small, doesn't reliably test determinism
    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."

    # Engine limits shared by both engines; the memory fraction is kept low
    # to avoid startup allocation failures.
    engine_limits = dict(
        gpu_memory_utilization=float(
            os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3")),
        max_model_len=int(os.getenv("VLLM_MAX_MODEL_LEN", "4096")),
        swap_space=int(os.getenv("VLLM_SWAP_SPACE_GB", "4")),
    )

    # Sampling settings, always with a fixed seed.
    sampling = SamplingParams(
        temperature=float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0")),
        top_p=float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95")),
        max_tokens=int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128")),
        seed=20240919,
    )

    needle_prompt = "There once was a "

    # Engines are tracked as they are created so the cleanup below only
    # touches the ones that were actually constructed.
    engines = []
    try:
        # Engine with bs=1 behavior, used for the baseline.
        llm_bs1 = LLM_with_max_seqs(model=model,
                                    max_num_seqs=1,
                                    **engine_limits)
        engines.append(llm_bs1)

        baseline_out = llm_bs1.generate([needle_prompt], sampling)
        assert len(baseline_out) == 1
        assert len(baseline_out[0].outputs) >= 1
        baseline_text = baseline_out[0].outputs[0].text

        # Engine with the larger batch limit (e.g., 64).
        llm_bsN = LLM_with_max_seqs(model=model,
                                    max_num_seqs=batch_size,
                                    **engine_limits)
        engines.append(llm_bsN)

        mismatches = 0
        for _ in range(num_trials):
            # Pick the needle slot first, then fill the remaining slots with
            # random prompts in index order.
            needle_pos = random.randint(0, batch_size - 1)
            prompts: list[str] = [
                needle_prompt if slot == needle_pos else _random_prompt()
                for slot in range(batch_size)
            ]

            outputs = llm_bsN.generate(prompts, sampling)

            # The needle's result is looked up by its position in the batch.
            needle_output = outputs[needle_pos]
            assert needle_output.prompt == needle_prompt
            assert len(needle_output.outputs) >= 1

            if needle_output.outputs[0].text != baseline_text:
                mismatches += 1

        passes = num_trials - mismatches
        # Report totals even when the test is about to fail.
        print(f"[determinism] total={num_trials}, passed={passes}, "
              f"failed={mismatches}, batch_size={batch_size}")

        if mismatches > 0:
            pytest.fail(
                f"Nondeterministic outputs detected: {mismatches} failed out "
                f"of {num_trials} trials (batch_size={batch_size}).")
    finally:
        # Shut engines down (ignoring shutdown errors) to release GPU memory.
        for engine in engines:
            with contextlib.suppress(Exception):
                engine.shutdown()
def
_extract_step_logprobs
(
request_output
):
if
getattr
(
request_output
,
"outputs"
,
None
):
inner
=
request_output
.
outputs
[
0
]
if
hasattr
(
inner
,
"logprobs"
)
and
inner
.
logprobs
is
not
None
:
t
=
torch
.
tensor
(
[
inner
.
logprobs
[
i
][
tid
].
logprob
for
i
,
tid
in
enumerate
(
inner
.
token_ids
)
],
dtype
=
torch
.
float32
,
)
return
t
return
None
@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="Requires CUDA to match production inference path.",
)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
    """Compare per-step logprobs of two prompts run singly vs. as one batch.

    Each prompt is first generated on its own, then both are generated in a
    single call. The sampled-token logprobs of every step must be bitwise
    equal between the two runs. The test is skipped if logprobs are missing
    from the outputs.
    """
    #model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    # NOTE(review): no dtype is passed here, so the engine uses its default
    # dtype. Eager mode is requested; presumably to reduce nondeterminism
    # from some backends.
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enforce_eager=True,
    )

    prompts = [
        "The capital of France is",
        "The capital of Germany is",
    ]

    sp = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=8,
        # Seed shouldn't matter at temperature=0, but keep it stable anyway.
        seed=1234,
        logprobs=5,
    )

    def _logprobs_or_skip(request_output):
        # Skip the whole test when the output carries no logprobs.
        extracted = _extract_step_logprobs(request_output)
        if extracted is None:
            pytest.skip("Logits are not available on RequestOutput; "
                        "enable logprobs return to run this test.")
        return extracted

    # BS=1: one generate() call per prompt.
    bs1_logprobs_per_prompt = []
    for prompt in prompts:
        single = llm.generate([prompt], sp, use_tqdm=False)
        assert len(single) == 1
        bs1_logprobs_per_prompt.append(_logprobs_or_skip(single[0]))

    # BS=2: both prompts in one generate() call.
    batched = llm.generate(prompts, sp, use_tqdm=False)
    assert len(batched) == len(prompts)
    bs2_logprobs_per_prompt = []
    for request_output in batched:
        bs2_logprobs_per_prompt.append(_logprobs_or_skip(request_output))

    # Step-by-step comparison for each prompt.
    paired = zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)
    for idx, (solo, together) in enumerate(paired):
        assert len(solo) == len(together), (
            f"Different number of generation steps for prompt index {idx}: "
            f"{len(solo)} (BS=1) vs {len(together)} (BS=2)")
        for step, (lhs, rhs) in enumerate(zip(solo, together)):
            assert lhs.shape == rhs.shape, (
                f"Logits shape mismatch at prompt {idx}, step {step}: "
                f"{lhs.shape} vs {rhs.shape}")
            # Bitwise exact equality.
            assert torch.equal(lhs, rhs), (
                f"Bitwise logprobs mismatch at prompt {idx}, step {step} "
                f"(dtype={lhs.dtype}, shape={lhs.shape}).")
def LLM_with_max_seqs(
    model: str,
    max_num_seqs: int,
    gpu_memory_utilization: float,
    max_model_len: int,
    swap_space: int,
) -> LLM:
    """
    Construct an LLM whose batch-size limit is ``max_num_seqs``.

    The remaining arguments are forwarded to ``LLM`` unchanged. The dtype is
    fixed to "float16". Tensor-parallel size comes from VLLM_TP_SIZE
    (default "1") and trust_remote_code is enabled only when
    VLLM_TRUST_REMOTE_CODE equals "1".
    """
    tp_size = int(os.getenv("VLLM_TP_SIZE", "1"))
    trust_remote = os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1"

    engine_kwargs = dict(
        model=model,
        max_num_seqs=max_num_seqs,
        # Caller-chosen cap on the GPU memory pool.
        gpu_memory_utilization=gpu_memory_utilization,
        # Caller-chosen bound on the context length.
        max_model_len=max_model_len,
        # Caller-chosen CPU swap space.
        swap_space=swap_space,
        dtype="float16",
        # Single-GPU by default; override externally if desired.
        tensor_parallel_size=tp_size,
        trust_remote_code=trust_remote,
    )
    return LLM(**engine_kwargs)
vllm/model_executor/layers/batch_invariant.py
0 → 100644
View file @
dc48ba0c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
os
from
collections
import
namedtuple
from
collections.abc
import
Callable
from
typing
import
Any
,
Union
import
torch
import
triton
import
triton.language
as
tl
def
_matmul_launch_metadata
(
grid
:
Callable
[...,
Any
],
kernel
:
Any
,
args
:
dict
[
str
,
Any
])
->
dict
[
str
,
Any
]:
ret
=
{}
m
,
n
,
k
=
args
[
"M"
],
args
[
"N"
],
args
[
"K"
]
ret
[
"name"
]
=
f
"
{
kernel
.
name
}
[M=
{
m
}
, N=
{
n
}
, K=
{
k
}
]"
if
"tiles_per_update"
in
args
:
ret
[
"name"
]
=
(
f
"
{
kernel
.
name
}
[M=
{
m
}
, N=
{
n
}
, K=
{
k
}
, "
f
"tiles_per_update=
{
args
[
'tiles_per_update'
]:
02
}
]"
)
if
"c_ptr"
in
args
:
bytes_per_elem
=
args
[
"c_ptr"
].
element_size
()
else
:
bytes_per_elem
=
1
if
args
[
"FP8_OUTPUT"
]
else
2
ret
[
f
"flops
{
bytes_per_elem
*
8
}
"
]
=
2.0
*
m
*
n
*
k
ret
[
"bytes"
]
=
bytes_per_elem
*
(
m
*
k
+
n
*
k
+
m
*
n
)
return
ret
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    # Map a linear tile id to (pid_m, pid_n) tile coordinates. Tiles are
    # grouped GROUP_SIZE_M rows at a time; the last group may hold fewer rows
    # when num_pid_m is not a multiple of GROUP_SIZE_M. NUM_SMS is accepted
    # but not used here.
    group_index = tile_id // num_pid_in_group
    group_first_row = group_index * GROUP_SIZE_M
    rows_in_group = min(num_pid_m - group_first_row, GROUP_SIZE_M)
    row = group_first_row + (tile_id % rows_in_group)
    col = (tile_id % num_pid_in_group) // rows_in_group
    return row, col
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,
    bias_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    # Persistent matmul: C[M, N] = A[M, K] @ B[K, N] (+ bias[N]).
    #
    # Each program starts at tile `program_id(0)` and then strides through the
    # output tiles in steps of NUM_SMS. For every tile it accumulates over K
    # in float32, optionally adds the bias, and stores the result cast to the
    # element type of c_ptr.
    #
    # A_LARGE / B_LARGE / C_LARGE switch the corresponding offsets to int64.
    # HAS_BIAS enables the bias load; bias_ptr is not read otherwise.
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # Separate tile counter used for the store coordinates; it is advanced by
    # NUM_SMS once per processed tile, so it tracks tile_id.
    tile_id_c = start_pid - NUM_SMS

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m,
                                    GROUP_SIZE_M, NUM_SMS)
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        # Out-of-range rows/columns are redirected to index 0 for the loads;
        # their results are dropped by c_mask at store time.
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M),
                                    BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
                                    BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(
                    tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
                              offs_k[None, :] * stride_ak)
            b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
                              offs_bn[None, :] * stride_bn)

            # Mask the K tail of the last block; masked lanes contribute 0.
            a = tl.load(a_ptrs,
                        mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K,
                        other=0.0)
            b = tl.load(b_ptrs,
                        mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K,
                        other=0.0)
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m,
                                    GROUP_SIZE_M, NUM_SMS)
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] +
                  stride_cn * offs_cn[None, :])
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if HAS_BIAS:
            bias_ptrs = bias_ptr + offs_cn
            bias = tl.load(bias_ptrs, mask=offs_cn < N,
                           other=0.0).to(tl.float32)
            accumulator += bias
        # Cast straight to the output element type. Going through float16
        # first would round bfloat16/float32 results twice and limit them to
        # the float16 range.
        c = accumulator.to(c_ptr.dtype.element_ty)
        tl.store(c_ptrs, c, mask=c_mask)
def matmul_persistent(a: torch.Tensor,
                      b: torch.Tensor,
                      bias: Union[torch.Tensor, None] = None):
    """Compute ``a @ b`` (optionally plus ``bias``) with the persistent
    Triton matmul kernel.

    Args:
        a: Left operand; its shape is unpacked as (M, K).
        b: Right operand; its shape is unpacked as (K, N). Must have the same
            dtype as ``a``.
        bias: Optional 1-D tensor added to the product inside the kernel.

    Returns:
        A new (M, N) tensor on ``a``'s device with ``a``'s dtype.

    Raises:
        AssertionError: on mismatched inner dimensions, mismatched dtypes, or
            a bias that is not 1-D.
        KeyError: if the dtype is not bfloat16, float16 or float32 (no tile
            configuration exists for it).
    """
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    assert bias is None or bias.dim() == 1, (
        "Currently assuming bias is 1D, let Horace know if you run into this")

    # Query the device that actually holds the operands rather than whichever
    # CUDA device happens to be current.
    NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count
    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype

    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=dtype)

    # 1D launch: at most one program per SM, never more than there are tiles.
    def grid(META):
        return (min(
            NUM_SMS,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) *
            triton.cdiv(N, META["BLOCK_SIZE_N"])), )

    # Fixed (non-autotuned) tile configuration per dtype.
    configs = {
        torch.bfloat16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float32: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
    }

    matmul_kernel_persistent[grid](
        a,
        b,
        c,
        bias,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        NUM_SMS=NUM_SMS,
        # Tensors with more than 2**31 elements get int64 offsets.
        A_LARGE=a.numel() > 2**31,
        B_LARGE=b.numel() > 2**31,
        C_LARGE=c.numel() > 2**31,
        HAS_BIAS=bias is not None,
        **configs[dtype],
    )
    return c
@triton.jit
def _log_softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Row-wise log_softmax of a 2D tensor; one program handles one row.

    The row is walked three times in chunks of BLOCK_SIZE columns: once for
    the row maximum, once for sum(exp(x - max)), and once to write
    x - max - log(sum).
    """
    # Row handled by this program (int64 so the row offset is 64-bit).
    row = tl.program_id(0).to(tl.int64)

    src_row_ptr = input_ptr + row * input_row_stride
    dst_row_ptr = output_ptr + row * output_row_stride

    # Pass 1: row maximum, for numerical stability.
    row_max = -float("inf")
    for block_start in range(0, n_cols, BLOCK_SIZE):
        cols = block_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        chunk = tl.load(src_row_ptr + cols,
                        mask=in_bounds,
                        other=-float("inf"))
        row_max = tl.max(tl.maximum(chunk, row_max))

    # Pass 2: sum of exp(x - row_max); out-of-range lanes contribute 0.
    exp_total = 0.0
    for block_start in range(0, n_cols, BLOCK_SIZE):
        cols = block_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        chunk = tl.load(src_row_ptr + cols, mask=in_bounds, other=0.0)
        shifted_exp = tl.exp(chunk - row_max)
        exp_total += tl.sum(tl.where(in_bounds, shifted_exp, 0.0))

    log_total = tl.log(exp_total)

    # Pass 3: write x - row_max - log(sum).
    for block_start in range(0, n_cols, BLOCK_SIZE):
        cols = block_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        chunk = tl.load(src_row_ptr + cols, mask=in_bounds)
        result = chunk - row_max - log_total
        tl.store(dst_row_ptr + cols, result, mask=in_bounds)
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Compute log_softmax using the Triton kernel.

    Args:
        input: Input tensor.
        dim: Dimension along which to compute log_softmax
            (only -1 or the last dimension is supported).

    Returns:
        Tensor with the same shape as ``input`` holding log_softmax along the
        last dimension.

    Raises:
        ValueError: if ``dim`` is not the last dimension.
    """
    if dim != -1 and dim != input.ndim - 1:
        raise ValueError("This implementation only supports log_softmax along "
                         "the last dimension")

    # Nothing to compute for an empty tensor; reshape(-1, 0) below would be
    # ambiguous and raise.
    if input.numel() == 0:
        return torch.empty_like(input)

    # Flatten all dimensions except the last one.
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()

    n_rows, n_cols = input_2d.shape

    # Allocate output tensor.
    output = torch.empty_like(input_2d)

    # Fixed chunk width; the kernel loops over the row in chunks of this size.
    BLOCK_SIZE = 1024

    # Launch kernel with one program per row.
    grid = (n_rows, )
    _log_softmax_kernel[grid](
        input_2d,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Reshape output back to original shape.
    return output.reshape(original_shape)
@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, N, K) where N is the dimension being reduced.
    Each program computes one output element (m_idx, k_idx).
    """
    # Program ID selects the output element. It is widened to int64, as in
    # _log_softmax_kernel, so the index * stride products below are 64-bit.
    pid = tl.program_id(0).to(tl.int64)

    # Compute output indices.
    m_idx = pid // K
    k_idx = pid % K

    # Bounds check.
    if m_idx >= M or k_idx >= K:
        return

    # Accumulate the sum across the reduction dimension, BLOCK_SIZE at a time.
    acc = 0.0
    for n_start in range(0, N, BLOCK_SIZE):
        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
        mask = n_offsets < N

        # Calculate input indices.
        input_idx = (m_idx * input_stride0 + n_offsets * input_stride1 +
                     k_idx * input_stride2)

        # Load and accumulate; masked lanes contribute 0.
        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
        acc += tl.sum(vals)

    # Compute mean and store.
    mean_val = acc / N
    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)
def mean_dim(input: torch.Tensor,
             dim: int,
             keepdim: bool = False,
             dtype: Union[torch.dtype, None] = None) -> torch.Tensor:
    """
    Triton implementation of torch.mean with single dimension reduction.

    Args:
        input: Input tensor (must be a CUDA tensor).
        dim: Single dimension along which to compute the mean; negative
            values count from the end.
        keepdim: Whether to keep the reduced dimension (with size 1).
        dtype: Output dtype. If None, uses the input dtype, or float32 when
            the input is neither floating point nor complex.

    Returns:
        Tensor with mean values along the specified dimension.

    Raises:
        AssertionError: if ``input`` is not on CUDA or ``dim`` is out of range.
    """
    # Validate inputs.
    assert input.is_cuda, "Input must be a CUDA tensor"
    assert -input.ndim <= dim < input.ndim, (
        f"Invalid dimension {dim} for tensor with {input.ndim} dimensions")

    # Handle negative dim.
    if dim < 0:
        dim = dim + input.ndim

    # Handle dtype: integer-like inputs (including uint8 and bool) are
    # averaged in float32.
    if dtype is None:
        if input.dtype.is_floating_point or input.dtype.is_complex:
            dtype = input.dtype
        else:
            dtype = torch.float32

    # Convert input to the working dtype if needed.
    if input.dtype != dtype:
        input = input.to(dtype)

    shape = list(input.shape)

    # Collapse the shape to (M, N, K) around the reduced dimension.
    M = 1
    for extent in shape[:dim]:
        M *= extent
    N = shape[dim]
    K = 1
    for extent in shape[dim + 1:]:
        K *= extent

    input_3d = input.reshape(M, N, K)

    # Create output shape.
    if keepdim:
        output_shape = shape.copy()
        output_shape[dim] = 1
    else:
        output_shape = shape[:dim] + shape[dim + 1:]

    # Create output tensor.
    output = torch.empty(output_shape, dtype=dtype, device=input.device)

    # View of the output as (M, K) for the kernel.
    if keepdim:
        output_2d = output.reshape(M, 1, K).squeeze(1)
    else:
        output_2d = output.reshape(M, K)

    # One program per output element.
    grid = (M * K, )
    BLOCK_SIZE = 1024

    mean_kernel[grid](
        input_3d,
        output_2d,
        input_3d.stride(0),
        input_3d.stride(1),
        input_3d.stride(2),
        output_2d.stride(0),
        # output_2d is always 2-D here, so stride(1) always exists.
        output_2d.stride(1),
        M,
        N,
        K,
        BLOCK_SIZE,
    )

    return output
def mm_batch_invariant(a, b):
    """Matrix-multiply ``a`` by ``b`` through ``matmul_persistent``.

    Has the two-operand signature used for the "aten::mm" override.
    """
    product = matmul_persistent(a, b)
    return product
def addmm_batch_invariant(bias, a, b):
    """Compute ``a @ b`` with ``bias`` added, through ``matmul_persistent``.

    Takes the bias first, matching how it is registered for "aten::addmm".
    """
    fused = matmul_persistent(a, b, bias=bias)
    return fused
def
_log_softmax_batch_invariant
(
input
,
dim
,
_half_to_float
):
assert
not
_half_to_float
,
"not implemented"
return
log_softmax
(
input
,
dim
=
dim
)
def mean_batch_invariant(input,
                         dim,
                         keepdim=False,
                         dtype: Union[torch.dtype, None] = None):
    """Mean over one or more dimensions, applying ``mean_dim`` once per
    dimension.

    Args:
        input: Tensor to reduce; it is converted to float32 first.
        dim: Dimensions to reduce. May be a sequence of ints, a single int,
            or None / an empty sequence, which reduces over every dimension.
        keepdim: Keep the reduced dimensions with size 1 when True.
        dtype: Must be None or torch.float32.

    Returns:
        A float32 tensor holding the mean.

    Raises:
        AssertionError: if ``dtype`` is neither None nor torch.float32.
    """
    assert dtype is None or dtype == torch.float32, \
        f"unsupported dtype: {dtype}"

    result = input.to(torch.float32)

    # Normalize `dim`: None or an empty list means "all dimensions", and a
    # bare int is treated as a one-element list.
    if dim is None:
        dims = list(range(input.ndim))
    elif isinstance(dim, int):
        dims = [dim]
    elif len(dim) == 0:
        dims = list(range(input.ndim))
    else:
        dims = list(dim)

    # Sort dimensions to reduce from largest to smallest to handle shifting
    # dims during iterative reduction.
    sorted_dims = sorted([d % input.ndim for d in dims], reverse=True)

    # Reduce one dimension at a time, keeping it so later indices stay valid.
    for d in sorted_dims:
        result = mean_dim(result, dim=d, keepdim=True)

    if not keepdim:
        # Squeeze the reduced dimensions, highest index first.
        for d in sorted_dims:
            result = result.squeeze(d)

    return result
# Module-level state for batch-invariant mode.
# True once enable_batch_invariant_mode() has run; reset to False by
# disable_batch_invariant_mode().
_batch_invariant_MODE = False
# The torch.library.Library that holds the registered overrides, or None when
# no overrides are registered.
_batch_invariant_LIB = None
def is_batch_invariant_mode_enabled():
    """Return the current value of the module's batch-invariant mode flag."""
    enabled = _batch_invariant_MODE
    return enabled
def enable_batch_invariant_mode():
    """Register the batch-invariant implementations for the CUDA backend.

    Does nothing if the mode flag is already set. Otherwise sets the flag,
    creates an "aten" IMPL library and registers replacements for mm, addmm,
    _log_softmax and mean.dim.
    """
    global _batch_invariant_MODE, _batch_invariant_LIB
    if _batch_invariant_MODE:
        return

    _batch_invariant_MODE = True
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")

    overrides = (
        ("aten::mm", mm_batch_invariant),
        ("aten::addmm", addmm_batch_invariant),
        ("aten::_log_softmax", _log_softmax_batch_invariant),
        ("aten::mean.dim", mean_batch_invariant),
    )
    for op_name, replacement in overrides:
        _batch_invariant_LIB.impl(op_name, replacement, "CUDA")
def disable_batch_invariant_mode():
    """Destroy the override library, if there is one, and clear the state.

    Afterwards the mode flag is False and no library is stored.
    """
    global _batch_invariant_MODE, _batch_invariant_LIB
    registered = _batch_invariant_LIB
    if registered is not None:
        registered._destroy()
    _batch_invariant_MODE, _batch_invariant_LIB = False, None
@contextlib.contextmanager
def set_batch_invariant_mode(enabled: bool = True):
    """Context manager that temporarily switches batch-invariant mode.

    Args:
        enabled: True to turn the mode on inside the ``with`` block, False to
            turn it off.

    On exit, including when the block raises, the mode is put back to what it
    was on entry: re-enabled if it had been on, disabled if it had been off.
    """
    # Only the previous on/off state is remembered. A saved library handle
    # would be stale if it gets destroyed inside the block, and restoring it
    # would leave the flag True with no live overrides.
    was_enabled = _batch_invariant_MODE

    if enabled:
        enable_batch_invariant_mode()
    else:
        disable_batch_invariant_mode()
    try:
        yield
    finally:
        # enable_* is a no-op when already enabled and disable_* is a no-op
        # when nothing is registered, so both restores are always safe.
        if was_enabled:
            enable_batch_invariant_mode()
        else:
            disable_batch_invariant_mode()
# Pair of attention tile sizes: block_m and block_n.
AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])


def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
    """Return the fixed attention block size, 16 x 16."""
    fixed_size = AttentionBlockSize(16, 16)
    return fixed_size
def vllm_kernel_override_batch_invariant():
    """Report whether VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT requests overrides.

    Returns True when the variable parses as a non-zero integer. Returns
    False when it is unset, is "0", or is not a valid integer.
    """
    raw = os.getenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "0")
    try:
        return int(raw) != 0
    except ValueError:
        # Unparsable values count as "not overridden".
        return False
def init_batch_invariance():
    """Turn on batch-invariant mode when the override env var requests it.

    When requested, VLLM_ATTENTION_BACKEND is set to "FLEX_ATTENTION" in the
    process environment and the batch-invariant overrides are enabled.
    Otherwise nothing happens.
    """
    # The csrc kernels read the same environment variable on their own.
    if not vllm_kernel_override_batch_invariant():
        return
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
    enable_batch_invariant_mode()
vllm/v1/attention/backends/flex_attention.py
View file @
dc48ba0c
...
...
@@ -18,6 +18,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
is_quantized_kv_cache
)
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_kernel_override_batch_invariant
)
from
vllm.utils
import
cdiv
,
is_torch_equal_or_newer
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
...
...
@@ -839,6 +841,11 @@ def get_kernel_options(query, block_m, block_n,
kernel_options
:
dict
[
str
,
Union
[
int
,
bool
]]
=
{
"FORCE_USE_FLEX_ATTENTION"
:
True
,
}
if
vllm_kernel_override_batch_invariant
():
kernel_options
[
"BLOCK_M"
]
=
16
kernel_options
[
"BLOCK_N"
]
=
16
kernel_options
[
"IS_DIVISIBLE"
]
=
False
return
kernel_options
if
use_direct_build
:
kernel_options
[
"BLOCK_M"
]
=
block_m
kernel_options
[
"BLOCK_N"
]
=
block_n
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
dc48ba0c
...
...
@@ -192,6 +192,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
from
vllm.model_executor.models.utils
import
set_cpu_offload_max_bytes
set_cpu_offload_max_bytes
(
int
(
self
.
cache_config
.
cpu_offload_gb
*
1024
**
3
))
from
vllm.model_executor.layers.batch_invariant
import
(
init_batch_invariance
)
init_batch_invariance
()
model_config
=
self
.
model_config
cache_config
=
self
.
cache_config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment