Commit 86527a47: [deterministic inference] Move batch invariant pkg to sglang (#10695)

Authored by Stefan He on Sep 21, 2025; committed via GitHub on Sep 21, 2025 (unverified signature). Parent commit: 134b4f7e.

Showing 3 changed files with 577 additions and 1 deletion:
- python/sglang/srt/batch_invariant_ops/__init__.py (+27, -0)
- python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py (+549, -0)
- python/sglang/srt/model_executor/model_runner.py (+1, -1)
python/sglang/srt/batch_invariant_ops/__init__.py (new file, mode 100644)
```python
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/__init__.py
from .batch_invariant_ops import (
    AttentionBlockSize,
    disable_batch_invariant_mode,
    enable_batch_invariant_mode,
    get_batch_invariant_attention_block_size,
    is_batch_invariant_mode_enabled,
    log_softmax,
    matmul_persistent,
    mean_dim,
    set_batch_invariant_mode,
)

__version__ = "0.1.0"

__all__ = [
    "set_batch_invariant_mode",
    "is_batch_invariant_mode_enabled",
    "disable_batch_invariant_mode",
    "enable_batch_invariant_mode",
    "matmul_persistent",
    "log_softmax",
    "mean_dim",
    "get_batch_invariant_attention_block_size",
    "AttentionBlockSize",
]
```
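For orientation, here is a minimal usage sketch of the exported API (an illustration, assuming an sglang checkout on the Python path and a CUDA device):

```python
# Sketch only: assumes sglang is importable and a CUDA GPU is present.
import torch

from sglang.srt.batch_invariant_ops import (
    disable_batch_invariant_mode,
    enable_batch_invariant_mode,
    is_batch_invariant_mode_enabled,
)

enable_batch_invariant_mode()
assert is_batch_invariant_mode_enabled()

# aten::mm on CUDA is now routed to the batch-invariant Triton matmul.
a = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
b = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
c = a @ b

disable_batch_invariant_mode()
```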
python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py (new file, mode 100644)
```python
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py
import contextlib
from collections import namedtuple
from collections.abc import Callable
from typing import Any, Dict

import torch
import triton
import triton.language as tl

__all__ = [
    "set_batch_invariant_mode",
    "is_batch_invariant_mode_enabled",
    "disable_batch_invariant_mode",
    "enable_batch_invariant_mode",
]


def _matmul_launch_metadata(
    grid: Callable[..., Any], kernel: Any, args: Dict[str, Any]
) -> Dict[str, Any]:
    ret = {}
    m, n, k = args["M"], args["N"], args["K"]
    ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
    if "tiles_per_update" in args:
        ret["name"] = (
            f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]"
        )
    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
    ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
    return ret


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n
```
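`_compute_pid` maps a linear tile index onto a grouped (pid_m, pid_n) walk so that consecutive tiles reuse the same A rows in L2. Since the function is a Triton jit, a pure-Python mirror can be handy for checking the ordering on the host; this sketch (`compute_pid_host` is a hypothetical helper, not part of the commit) shows the walk for a small grid:

```python
# Sketch only: host-side mirror of _compute_pid for inspecting tile order.
def compute_pid_host(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n


# A 4x4 tile grid with GROUP_SIZE_M=2: tiles walk down two rows of M before
# advancing along N, which improves cache reuse of the A tiles.
num_pid_m, num_pid_n, GROUP_SIZE_M = 4, 4, 2
num_pid_in_group = GROUP_SIZE_M * num_pid_n
order = [
    compute_pid_host(t, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
    for t in range(num_pid_m * num_pid_n)
]
# order == [(0,0), (1,0), (0,1), (1,1), ..., (2,0), (3,0), (2,1), ...]
```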
```python
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,  #
    bias_ptr,
    M,
    N,
    K,  #
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    tile_id_c = start_pid - NUM_SMS
    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
            b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

            a = tl.load(
                a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            b = tl.load(
                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if HAS_BIAS:
            bias_ptrs = bias_ptr + offs_cn
            bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
            accumulator += bias
        if c_ptr.dtype.element_ty == tl.float8e4nv:
            c = accumulator.to(tl.float8e4nv)
        else:
            c = accumulator.to(tl.float16)
        tl.store(c_ptrs, c, mask=c_mask)
```
```python
def matmul_persistent(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    assert (
        bias is None or bias.dim() == 1
    ), "Currently assuming bias is 1D, let Horace know if you run into this"
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=dtype)

    # 1D launch kernel where each block gets its own program.
    def grid(META):
        return (
            min(
                NUM_SMS,
                triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            ),
        )

    configs = {
        torch.bfloat16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float32: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
    }
    # print(a.device, b.device, c.device)
    matmul_kernel_persistent[grid](
        a,
        b,
        c,  #
        bias,
        M,
        N,
        K,  #
        a.stride(0),
        a.stride(1),  #
        b.stride(0),
        b.stride(1),  #
        c.stride(0),
        c.stride(1),  #
        NUM_SMS=NUM_SMS,  #
        A_LARGE=a.numel() > 2**31,
        B_LARGE=b.numel() > 2**31,
        C_LARGE=c.numel() > 2**31,
        HAS_BIAS=bias is not None,
        **configs[dtype],
    )
    return c
```
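Note that `matmul_persistent` picks one fixed kernel configuration per dtype (bfloat16, float16, or float32) instead of autotuning, so the reduction order per output tile does not change with problem size; that is what makes the result batch-invariant. A minimal sketch of calling it directly, assuming a CUDA device:

```python
# Sketch only: assumes a CUDA GPU; dtype must be bf16, fp16, or fp32.
import torch
from sglang.srt.batch_invariant_ops import matmul_persistent

a = torch.randn(512, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 256, device="cuda", dtype=torch.float16)
bias = torch.randn(256, device="cuda", dtype=torch.float16)

c = matmul_persistent(a, b, bias=bias)

# The same rows should come out bitwise identical regardless of how many
# other rows are in the batch, since each tile is reduced in a fixed order.
c_sub = matmul_persistent(a[:128], b, bias=bias)
assert torch.equal(c[:128], c_sub)
```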
```python
@triton.jit
def _log_softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute log_softmax along the last dimension of a 2D tensor.
    Each block handles one row of the input tensor.
    """
    # Get the row index for this block
    row_idx = tl.program_id(0).to(tl.int64)

    # Compute base pointers for input and output rows
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Find maximum value in the row for numerical stability
    max_val = -float("inf")
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))
        # Update maximum
        max_val = tl.max(tl.maximum(vals, max_val))

    # Step 2: Compute sum of exp(x - max_val)
    sum_exp = 0.0
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        # Compute exp(x - max_val) and accumulate
        exp_vals = tl.exp(vals - max_val)
        sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))

    # Compute log(sum_exp)
    log_sum_exp = tl.log(sum_exp)

    # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask)
        # Compute log_softmax
        output = vals - max_val - log_sum_exp
        # Store results
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
```
```python
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Compute log_softmax using Triton kernel.

    Args:
        input: Input tensor
        dim: Dimension along which to compute log_softmax (only -1 or last dim supported)

    Returns:
        Tensor with log_softmax applied along the specified dimension
    """
    if dim != -1 and dim != input.ndim - 1:
        raise ValueError(
            "This implementation only supports log_softmax along the last dimension"
        )

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()

    n_rows, n_cols = input_2d.shape

    # Allocate output tensor
    output = torch.empty_like(input_2d)

    # Choose block size based on the number of columns
    BLOCK_SIZE = 1024

    # Launch kernel with one block per row
    grid = (n_rows,)
    _log_softmax_kernel[grid](
        input_2d,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Reshape output back to original shape
    return output.reshape(original_shape)
```
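Because each row is reduced with the same fixed BLOCK_SIZE of 1024, the per-row accumulation order is independent of how many rows are in the batch. A minimal usage sketch, assuming a CUDA device:

```python
# Sketch only: assumes a CUDA GPU.
import torch
from sglang.srt.batch_invariant_ops import log_softmax

x = torch.randn(8, 32000, device="cuda", dtype=torch.float32)
y = log_softmax(x, dim=-1)

# Should agree with the reference implementation within fp tolerance.
torch.testing.assert_close(y, torch.log_softmax(x, dim=-1), rtol=1e-5, atol=1e-5)
```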
```python
@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, N, K) where N is the dimension being reduced.
    """
    # Program ID gives us which output element we're computing
    pid = tl.program_id(0)

    # Compute output indices
    m_idx = pid // K
    k_idx = pid % K

    # Bounds check
    if m_idx >= M or k_idx >= K:
        return

    # Accumulate sum across reduction dimension
    acc = 0.0
    for n_start in range(0, N, BLOCK_SIZE):
        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
        mask = n_offsets < N

        # Calculate input indices
        input_idx = (
            m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
        )

        # Load and accumulate
        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
        acc += tl.sum(vals)

    # Compute mean and store
    mean_val = acc / N
    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)
```
```python
def mean_dim(
    input: torch.Tensor,
    dim: int,
    keepdim: bool = False,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """
    Triton implementation of torch.mean with single dimension reduction.

    Args:
        input: Input tensor
        dim: Single dimension along which to compute mean
        keepdim: Whether to keep the reduced dimension
        dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs)

    Returns:
        Tensor with mean values along specified dimension
    """
    # Validate inputs
    assert input.is_cuda, "Input must be a CUDA tensor"
    assert (
        -input.ndim <= dim < input.ndim
    ), f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"

    # Handle negative dim
    if dim < 0:
        dim = dim + input.ndim

    # Handle dtype
    if dtype is None:
        if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            dtype = torch.float32
        else:
            dtype = input.dtype

    # Convert input to appropriate dtype if needed
    if input.dtype != dtype:
        input = input.to(dtype)

    # Get input shape and strides
    shape = list(input.shape)

    # Calculate dimensions for kernel
    M = 1
    for i in range(dim):
        M *= shape[i]
    N = shape[dim]
    K = 1
    for i in range(dim + 1, len(shape)):
        K *= shape[i]

    # Reshape input to 3D view (M, N, K)
    input_3d = input.reshape(M, N, K)

    # Create output shape
    if keepdim:
        output_shape = shape.copy()
        output_shape[dim] = 1
    else:
        output_shape = shape[:dim] + shape[dim + 1 :]

    # Create output tensor
    output = torch.empty(output_shape, dtype=dtype, device=input.device)

    # Reshape output for kernel
    if keepdim:
        output_2d = output.reshape(M, 1, K).squeeze(1)
    else:
        output_2d = output.reshape(M, K)

    # Launch kernel
    grid = (M * K,)
    BLOCK_SIZE = 1024

    mean_kernel[grid](
        input_3d,
        output_2d,
        input_3d.stride(0),
        input_3d.stride(1),
        input_3d.stride(2),
        output_2d.stride(0),
        output_2d.stride(1) if output_2d.ndim > 1 else 0,
        M,
        N,
        K,
        BLOCK_SIZE,
    )

    return output
```
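A minimal usage sketch of `mean_dim`, assuming a CUDA device:

```python
# Sketch only: assumes a CUDA GPU.
import torch
from sglang.srt.batch_invariant_ops import mean_dim

x = torch.randn(4, 1024, 8, device="cuda")
m = mean_dim(x, dim=1, keepdim=True)  # shape (4, 1, 8)

torch.testing.assert_close(m, x.mean(dim=1, keepdim=True), rtol=1e-5, atol=1e-5)
```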
```python
def mm_batch_invariant(a, b):
    return matmul_persistent(a, b)


def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)


def _log_softmax_batch_invariant(input, dim, _half_to_float):
    assert not _half_to_float, "not implemented"
    return log_softmax(input, dim=dim)


def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
    assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
    if len(dim) == 1:
        return mean_dim(input, dim[0], keepdim=keepdim)
    else:
        assert input.dtype in {
            torch.float16,
            torch.bfloat16,
            torch.float32,
        }, "only float types supported for now"
        n_elems = 1
        for d in dim:
            n_elems *= input.shape[d]
        return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
```
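For reductions over more than one dimension, `mean_batch_invariant` skips the Triton kernel and instead divides an fp32 sum by the reduced element count. A quick sketch of that arithmetic (an illustration, assuming a CUDA device):

```python
# Sketch only: mirrors the multi-dim branch of mean_batch_invariant above.
import torch

x = torch.randn(2, 3, 4, device="cuda", dtype=torch.bfloat16)
dim = (0, 2)
n_elems = x.shape[0] * x.shape[2]  # 8 elements reduced per output entry

out = torch.sum(x, dim=dim, dtype=torch.float32) / n_elems
torch.testing.assert_close(out, x.float().mean(dim=dim), rtol=1e-3, atol=1e-3)
```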
```python
_batch_invariant_MODE = False
_batch_invariant_LIB = None


def is_batch_invariant_mode_enabled():
    return _batch_invariant_MODE


def enable_batch_invariant_mode():
    global _batch_invariant_MODE, _batch_invariant_LIB
    if _batch_invariant_MODE:
        return

    _batch_invariant_MODE = True
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::_log_softmax", _log_softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")


def disable_batch_invariant_mode():
    global _batch_invariant_MODE, _batch_invariant_LIB
    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    _batch_invariant_MODE = False
    _batch_invariant_LIB = None


@contextlib.contextmanager
def set_batch_invariant_mode(enabled: bool = True):
    global _batch_invariant_MODE, _batch_invariant_LIB
    old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
    if enabled:
        enable_batch_invariant_mode()
    else:
        disable_batch_invariant_mode()
    yield
    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    _batch_invariant_MODE, _batch_invariant_LIB = old_data
```
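`set_batch_invariant_mode` saves the previous mode and restores it on exit, so it can wrap a region of computation without disturbing global state. A minimal sketch, assuming a CUDA device:

```python
# Sketch only: assumes a CUDA GPU.
import torch
from sglang.srt.batch_invariant_ops import (
    is_batch_invariant_mode_enabled,
    set_batch_invariant_mode,
)

a = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
b = torch.randn(64, 16, device="cuda", dtype=torch.bfloat16)

with set_batch_invariant_mode():
    assert is_batch_invariant_mode_enabled()
    c = a @ b  # dispatched to the batch-invariant Triton matmul

# The previous mode (disabled here) is restored on exit.
assert not is_batch_invariant_mode_enabled()
```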
```python
AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])


def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
    return AttentionBlockSize(block_m=16, block_n=16)
```
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -408,7 +408,7 @@ class ModelRunner:
         # Enable batch invariant mode
         if server_args.enable_deterministic_inference:
-            from batch_invariant_ops import enable_batch_invariant_mode
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode

             enable_batch_invariant_mode()
```
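The only functional change outside the new package is this import path: deterministic inference now uses the copy vendored under `sglang.srt` rather than the external `batch_invariant_ops` dependency. The mode is gated on the `enable_deterministic_inference` server argument shown in the diff; on the command line this presumably corresponds to a flag like `--enable-deterministic-inference`, though that spelling is an inference from sglang's usual conventions, not something this diff shows.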