Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
656
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
580 additions
and
137 deletions
+580
-137
vllm/transformers_utils/gguf_utils.py
vllm/transformers_utils/gguf_utils.py
+71
-0
vllm/transformers_utils/processor.py
vllm/transformers_utils/processor.py
+2
-1
vllm/transformers_utils/processors/hunyuan_vl.py
vllm/transformers_utils/processors/hunyuan_vl.py
+1
-1
vllm/transformers_utils/runai_utils.py
vllm/transformers_utils/runai_utils.py
+1
-3
vllm/transformers_utils/utils.py
vllm/transformers_utils/utils.py
+0
-72
vllm/utils/argparse_utils.py
vllm/utils/argparse_utils.py
+1
-13
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+18
-0
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+14
-23
vllm/utils/hashing.py
vllm/utils/hashing.py
+36
-0
vllm/utils/nvtx_pytorch_hooks.py
vllm/utils/nvtx_pytorch_hooks.py
+286
-0
vllm/utils/serial_utils.py
vllm/utils/serial_utils.py
+49
-4
vllm/utils/system_utils.py
vllm/utils/system_utils.py
+4
-0
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+1
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-4
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+16
-15
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+75
-1
No files found.
Too many changes to show.
To preserve performance only
656 of 656+
files are displayed.
Plain diff
Email patch
vllm/transformers_utils/gguf_utils.py
View file @
8d75f22e
...
...
@@ -2,10 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""GGUF utility functions."""
from
functools
import
cache
from
os
import
PathLike
from
pathlib
import
Path
import
gguf
import
regex
as
re
from
gguf.constants
import
Keys
,
VisionProjectorType
from
gguf.quants
import
GGMLQuantizationType
from
transformers
import
Gemma3Config
,
PretrainedConfig
,
SiglipVisionConfig
from
vllm.logger
import
init_logger
...
...
@@ -15,6 +19,73 @@ from .repo_utils import list_filtered_repo_files
logger
=
init_logger
(
__name__
)
@
cache
def
check_gguf_file
(
model
:
str
|
PathLike
)
->
bool
:
"""Check if the file is a GGUF model."""
model
=
Path
(
model
)
if
not
model
.
is_file
():
return
False
elif
model
.
suffix
==
".gguf"
:
return
True
try
:
with
model
.
open
(
"rb"
)
as
f
:
header
=
f
.
read
(
4
)
return
header
==
b
"GGUF"
except
Exception
as
e
:
logger
.
debug
(
"Error reading file %s: %s"
,
model
,
e
)
return
False
@
cache
def
is_remote_gguf
(
model
:
str
|
Path
)
->
bool
:
"""Check if the model is a remote GGUF model."""
pattern
=
r
"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*:[A-Za-z0-9_+-]+$"
model
=
str
(
model
)
if
re
.
fullmatch
(
pattern
,
model
):
_
,
quant_type
=
model
.
rsplit
(
":"
,
1
)
return
is_valid_gguf_quant_type
(
quant_type
)
return
False
def
is_valid_gguf_quant_type
(
gguf_quant_type
:
str
)
->
bool
:
"""Check if the quant type is a valid GGUF quant type."""
return
getattr
(
GGMLQuantizationType
,
gguf_quant_type
,
None
)
is
not
None
def
split_remote_gguf
(
model
:
str
|
Path
)
->
tuple
[
str
,
str
]:
"""Split the model into repo_id and quant type."""
model
=
str
(
model
)
if
is_remote_gguf
(
model
):
parts
=
model
.
rsplit
(
":"
,
1
)
return
(
parts
[
0
],
parts
[
1
])
raise
ValueError
(
f
"Wrong GGUF model or invalid GGUF quant type:
{
model
}
.
\n
"
"- It should be in repo_id:quant_type format.
\n
"
f
"- Valid GGMLQuantizationType values:
{
GGMLQuantizationType
.
_member_names_
}
"
,
)
def
is_gguf
(
model
:
str
|
Path
)
->
bool
:
"""Check if the model is a GGUF model.
Args:
model: Model name, path, or Path object to check.
Returns:
True if the model is a GGUF model, False otherwise.
"""
model
=
str
(
model
)
# Check if it's a local GGUF file
if
check_gguf_file
(
model
):
return
True
# Check if it's a remote GGUF model (repo_id:quant_type format)
return
is_remote_gguf
(
model
)
def
detect_gguf_multimodal
(
model
:
str
)
->
Path
|
None
:
"""Check if GGUF model has multimodal projector file.
...
...
vllm/transformers_utils/processor.py
View file @
8d75f22e
...
...
@@ -18,7 +18,8 @@ from transformers.processing_utils import ProcessorMixin
from
transformers.video_processing_utils
import
BaseVideoProcessor
from
typing_extensions
import
TypeVar
from
vllm.transformers_utils.utils
import
convert_model_repo_to_path
,
is_gguf
from
vllm.transformers_utils.gguf_utils
import
is_gguf
from
vllm.transformers_utils.utils
import
convert_model_repo_to_path
from
vllm.utils.func_utils
import
get_allowed_kwarg_only_overrides
if
TYPE_CHECKING
:
...
...
vllm/transformers_utils/processors/hunyuan_vl.py
View file @
8d75f22e
...
...
@@ -123,7 +123,7 @@ class HunYuanVLProcessor(ProcessorMixin):
attention_mask
=
input_ids
.
ne
(
self
.
pad_id
)
text_inputs
[
"attention_mask"
]
=
attention_mask
text_inputs
[
"imgs_pos"
]
=
[
self
.
get_imgs_pos
(
input_ids
)
]
text_inputs
[
"imgs_pos"
]
=
[
self
.
get_imgs_pos
(
e
)
for
e
in
input_ids
]
# image_inputs["imgs"] = [[image_inputs["pixel_values"]]]
return_tensors
=
kwargs
.
pop
(
"return_tensors"
,
None
)
...
...
vllm/transformers_utils/runai_utils.py
View file @
8d75f22e
...
...
@@ -18,9 +18,7 @@ SUPPORTED_SCHEMES = ["s3://", "gs://"]
try
:
from
runai_model_streamer
import
list_safetensors
as
runai_list_safetensors
from
runai_model_streamer
import
pull_files
as
runai_pull_files
except
(
ImportError
,
OSError
):
# see https://github.com/run-ai/runai-model-streamer/issues/26
# OSError will be raised on arm64 platform
except
ImportError
:
runai_model_streamer
=
PlaceholderModule
(
"runai_model_streamer"
)
# type: ignore[assignment]
runai_pull_files
=
runai_model_streamer
.
placeholder_attr
(
"pull_files"
)
runai_list_safetensors
=
runai_model_streamer
.
placeholder_attr
(
"list_safetensors"
)
...
...
vllm/transformers_utils/utils.py
View file @
8d75f22e
...
...
@@ -9,8 +9,6 @@ from os import PathLike
from
pathlib
import
Path
from
typing
import
Any
from
gguf
import
GGMLQuantizationType
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
...
...
@@ -29,76 +27,6 @@ def is_cloud_storage(model_or_path: str) -> bool:
return
is_s3
(
model_or_path
)
or
is_gcs
(
model_or_path
)
@
cache
def
check_gguf_file
(
model
:
str
|
PathLike
)
->
bool
:
"""Check if the file is a GGUF model."""
model
=
Path
(
model
)
if
not
model
.
is_file
():
return
False
elif
model
.
suffix
==
".gguf"
:
return
True
try
:
with
model
.
open
(
"rb"
)
as
f
:
header
=
f
.
read
(
4
)
return
header
==
b
"GGUF"
except
Exception
as
e
:
logger
.
debug
(
"Error reading file %s: %s"
,
model
,
e
)
return
False
@
cache
def
is_remote_gguf
(
model
:
str
|
Path
)
->
bool
:
"""Check if the model is a remote GGUF model."""
model
=
str
(
model
)
return
(
(
not
is_cloud_storage
(
model
))
and
(
not
model
.
startswith
((
"http://"
,
"https://"
)))
and
(
"/"
in
model
and
":"
in
model
)
and
is_valid_gguf_quant_type
(
model
.
rsplit
(
":"
,
1
)[
1
])
)
def
is_valid_gguf_quant_type
(
gguf_quant_type
:
str
)
->
bool
:
"""Check if the quant type is a valid GGUF quant type."""
return
getattr
(
GGMLQuantizationType
,
gguf_quant_type
,
None
)
is
not
None
def
split_remote_gguf
(
model
:
str
|
Path
)
->
tuple
[
str
,
str
]:
"""Split the model into repo_id and quant type."""
model
=
str
(
model
)
if
is_remote_gguf
(
model
):
parts
=
model
.
rsplit
(
":"
,
1
)
return
(
parts
[
0
],
parts
[
1
])
raise
ValueError
(
"Wrong GGUF model or invalid GGUF quant type: %s.
\n
"
"- It should be in repo_id:quant_type format.
\n
"
"- Valid GGMLQuantizationType values: %s"
,
model
,
GGMLQuantizationType
.
_member_names_
,
)
def
is_gguf
(
model
:
str
|
Path
)
->
bool
:
"""Check if the model is a GGUF model.
Args:
model: Model name, path, or Path object to check.
Returns:
True if the model is a GGUF model, False otherwise.
"""
model
=
str
(
model
)
# Check if it's a local GGUF file
if
check_gguf_file
(
model
):
return
True
# Check if it's a remote GGUF model (repo_id:quant_type format)
return
is_remote_gguf
(
model
)
def
modelscope_list_repo_files
(
repo_id
:
str
,
revision
:
str
|
None
=
None
,
...
...
vllm/utils/argparse_utils.py
View file @
8d75f22e
...
...
@@ -244,9 +244,8 @@ class FlexibleArgumentParser(ArgumentParser):
else
:
key
=
pattern
.
sub
(
repl
,
arg
,
count
=
1
)
processed_args
.
append
(
key
)
elif
arg
.
startswith
(
"-O"
)
and
arg
!=
"-O"
and
arg
[
2
]
!=
"."
:
elif
arg
.
startswith
(
"-O"
)
and
arg
!=
"-O"
:
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<optimization_level> here
optimization_level
=
arg
[
3
:]
if
arg
[
2
]
==
"="
else
arg
[
2
:]
processed_args
+=
[
"--optimization-level"
,
optimization_level
]
...
...
@@ -257,17 +256,6 @@ class FlexibleArgumentParser(ArgumentParser):
):
# Convert -O <n> to --optimization-level <n>
processed_args
.
append
(
"--optimization-level"
)
elif
arg
.
startswith
(
"-O."
):
# Handle -O.* dotted syntax - ALL dotted syntax is deprecated
logger
.
warning_once
(
"The -O.* dotted syntax for --compilation-config is "
"deprecated and will be removed in v0.13.0 or v1.0.0"
", whichever is earlier. Please use -cc.* instead. "
"Example: -cc.backend=eager instead of "
"-O.backend=eager."
)
converted_arg
=
arg
.
replace
(
"-O"
,
"-cc"
,
1
)
processed_args
.
append
(
converted_arg
)
else
:
processed_args
.
append
(
arg
)
...
...
vllm/utils/deep_gemm.py
View file @
8d75f22e
...
...
@@ -481,8 +481,25 @@ def should_use_deepgemm_for_fp8_linear(
)
def
should_use_deepgemm_for_fp8_linear_for_nk
(
output_dtype
:
torch
.
dtype
,
shape0
:
int
,
shape1
:
int
,
supports_deep_gemm
:
bool
|
None
=
None
,
):
if
supports_deep_gemm
is
None
:
supports_deep_gemm
=
is_deep_gemm_supported
()
return
(
supports_deep_gemm
and
output_dtype
==
torch
.
bfloat16
and
shape0
%
128
==
0
and
shape1
%
128
==
0
)
__all__
=
[
"calc_diff"
,
"DeepGemmQuantScaleFMT"
,
"fp8_gemm_nt"
,
"m_grouped_fp8_gemm_nt_contiguous"
,
"fp8_m_grouped_gemm_nt_masked"
,
...
...
@@ -494,6 +511,7 @@ __all__ = [
"is_deep_gemm_supported"
,
"get_num_sms"
,
"should_use_deepgemm_for_fp8_linear"
,
"should_use_deepgemm_for_fp8_linear_for_nk"
,
"get_col_major_tma_aligned_tensor"
,
"get_mk_alignment_for_contiguous_layout"
,
]
vllm/utils/flashinfer.py
View file @
8d75f22e
...
...
@@ -267,21 +267,16 @@ def supports_trtllm_attention() -> bool:
return
current_platform
.
is_device_capability
(
100
)
and
has_nvidia_artifactory
()
@
functools
.
cache
def
_force_use_trtllm_attention
(
env_value
:
bool
|
None
)
->
bool
|
None
:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if
env_value
is
not
None
:
logger
.
info_once
(
"VLLM_USE_TRTLLM_ATTENTION is set to %s"
,
env_value
)
return
env_value
def
force_use_trtllm_attention
()
->
bool
|
None
:
"""
Return `None` if
VLLM_USE_TRTLLM_ATTENTION
is not set,
Return `None` if
--attention-config.use_trtllm_attention
is not set,
return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used.
"""
return
_force_use_trtllm_attention
(
envs
.
VLLM_USE_TRTLLM_ATTENTION
)
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
return
vllm_config
.
attention_config
.
use_trtllm_attention
def
can_use_trtllm_attention
(
num_qo_heads
:
int
,
num_kv_heads
:
int
)
->
bool
:
...
...
@@ -307,7 +302,7 @@ def use_trtllm_attention(
"""Return `True` if TRTLLM attention is used."""
force_use_trtllm
=
force_use_trtllm_attention
()
#
Environment variable
is set to 0 - respect it
#
CLI argument
is set to 0 - respect it
if
force_use_trtllm
is
not
None
and
not
force_use_trtllm
:
return
False
...
...
@@ -324,7 +319,7 @@ def use_trtllm_attention(
if
force_use_trtllm
:
logger
.
warning_once
(
"TRTLLM attention is not supported on this platform, "
"but
VLLM_USE_TRTLLM_ATTENTION
is set to 1"
"but
--attention-config.use_trtllm_attention
is set to 1"
)
return
False
...
...
@@ -333,7 +328,8 @@ def use_trtllm_attention(
if
force_use_trtllm
:
logger
.
warning_once
(
"TRTLLM attention is not supported for this combination of "
"query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
"query and key heads, but --attention-config.use_trtllm_attention is "
"set to 1"
)
return
False
...
...
@@ -354,7 +350,7 @@ def use_trtllm_attention(
return
True
if
force_use_trtllm
is
None
:
#
Environment variable
not set - use auto-detection
#
CLI argument
not set - use auto-detection
if
is_prefill
:
# Prefill auto-detection
use_trtllm
=
kv_cache_dtype
==
"auto"
...
...
@@ -367,8 +363,10 @@ def use_trtllm_attention(
logger
.
warning_once
(
"Using TRTLLM decode attention (auto-detected)."
)
return
use_trtllm
# Environment variable is set to 1 - respect it
logger
.
info_once
(
"Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)"
)
# CLI argument is set to 1 - respect it
logger
.
info_once
(
"Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
)
return
True
...
...
@@ -500,12 +498,6 @@ def flashinfer_scaled_fp8_mm(
return
output
@
functools
.
cache
def
flashinfer_disable_q_quantization
()
->
bool
:
"""Cache result which only depends on the environment"""
return
envs
.
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
__all__
=
[
"has_flashinfer"
,
"flashinfer_trtllm_fp8_block_scale_moe"
,
...
...
@@ -526,7 +518,6 @@ __all__ = [
"supports_trtllm_attention"
,
"can_use_trtllm_attention"
,
"use_trtllm_attention"
,
"flashinfer_disable_q_quantization"
,
"flashinfer_scaled_fp4_mm"
,
"flashinfer_scaled_fp8_mm"
,
]
vllm/utils/hashing.py
View file @
8d75f22e
...
...
@@ -11,6 +11,17 @@ from typing import Any
import
cbor2
try
:
# It is important that this remains an optional dependency.
# It would not be allowed in environments with strict security controls,
# so it's best not to have it installed when not in use.
import
xxhash
as
_xxhash
if
not
hasattr
(
_xxhash
,
"xxh3_128_digest"
):
_xxhash
=
None
except
ImportError
:
# pragma: no cover
_xxhash
=
None
def
sha256
(
input
:
Any
)
->
bytes
:
"""Hash any picklable Python object using SHA-256.
...
...
@@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes:
return
hashlib
.
sha256
(
input_bytes
).
digest
()
def
_xxhash_digest
(
input_bytes
:
bytes
)
->
bytes
:
if
_xxhash
is
None
:
raise
ModuleNotFoundError
(
"xxhash is required for the 'xxhash' prefix caching hash algorithms. "
"Install it via `pip install xxhash`."
)
return
_xxhash
.
xxh3_128_digest
(
input_bytes
)
def
xxhash
(
input
:
Any
)
->
bytes
:
"""Hash picklable objects using xxHash."""
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
return
_xxhash_digest
(
input_bytes
)
def
xxhash_cbor
(
input
:
Any
)
->
bytes
:
"""Hash objects serialized with CBOR using xxHash."""
input_bytes
=
cbor2
.
dumps
(
input
,
canonical
=
True
)
return
_xxhash_digest
(
input_bytes
)
def
get_hash_fn_by_name
(
hash_fn_name
:
str
)
->
Callable
[[
Any
],
bytes
]:
"""Get a hash function by name, or raise an error if the function is not found.
...
...
@@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
return
sha256
if
hash_fn_name
==
"sha256_cbor"
:
return
sha256_cbor
if
hash_fn_name
==
"xxhash"
:
return
xxhash
if
hash_fn_name
==
"xxhash_cbor"
:
return
xxhash_cbor
raise
ValueError
(
f
"Unsupported hash function:
{
hash_fn_name
}
"
)
...
...
vllm/utils/nvtx_pytorch_hooks.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
contextmanager
import
torch
import
torch.cuda.nvtx
as
nvtx
def
print_tensor
(
tensor_obj
,
prefix
,
tensor_list
=
None
):
"""Descends iterators that contains Tensors and prints the Tensor.
Recursive function that descends iterator type arguments until
it finds a Tensor object.
"""
if
tensor_list
is
None
:
tensor_list
=
[]
if
isinstance
(
tensor_obj
,
(
list
,
tuple
)):
for
ten
in
tensor_obj
:
tensor_list
=
print_tensor
(
ten
,
prefix
,
tensor_list
)
elif
isinstance
(
tensor_obj
,
torch
.
Tensor
):
tensor_dims
=
list
(
tensor_obj
.
size
())
tensor_list
.
append
(
tensor_dims
)
return
tensor_list
def
process_layer_params
(
module_obj
):
"""Extract the static parameters from LLM and VLM relevant layer types"""
param_info
=
{}
# Extract parameters for layers commonly used in LLMs and VLMs
if
isinstance
(
module_obj
,
(
torch
.
nn
.
Conv1d
,
torch
.
nn
.
Conv2d
,
torch
.
nn
.
Conv3d
)):
conv_params
=
{}
conv_params
[
"in_chan"
]
=
module_obj
.
in_channels
conv_params
[
"out_chan"
]
=
module_obj
.
out_channels
conv_params
[
"filter_dim"
]
=
module_obj
.
kernel_size
conv_params
[
"stride"
]
=
module_obj
.
stride
conv_params
[
"padding"
]
=
module_obj
.
padding
conv_params
[
"dilation"
]
=
module_obj
.
dilation
conv_params
[
"transposed"
]
=
module_obj
.
transposed
conv_params
[
"output_padding"
]
=
module_obj
.
output_padding
conv_params
[
"groups"
]
=
module_obj
.
groups
conv_params
[
"padding_mode"
]
=
module_obj
.
padding_mode
param_info
=
conv_params
elif
isinstance
(
module_obj
,
(
torch
.
nn
.
ConvTranspose1d
,
torch
.
nn
.
ConvTranspose2d
,
torch
.
nn
.
ConvTranspose3d
,
),
):
convtranspose_params
=
{}
convtranspose_params
[
"in_chan"
]
=
module_obj
.
in_channels
convtranspose_params
[
"out_chan"
]
=
module_obj
.
out_channels
convtranspose_params
[
"filter_dim"
]
=
module_obj
.
kernel_size
convtranspose_params
[
"stride"
]
=
module_obj
.
stride
convtranspose_params
[
"padding"
]
=
module_obj
.
padding
convtranspose_params
[
"dilation"
]
=
module_obj
.
dilation
convtranspose_params
[
"transposed"
]
=
module_obj
.
transposed
convtranspose_params
[
"output_padding"
]
=
module_obj
.
output_padding
convtranspose_params
[
"groups"
]
=
module_obj
.
groups
convtranspose_params
[
"padding_mode"
]
=
module_obj
.
padding_mode
param_info
=
convtranspose_params
elif
isinstance
(
module_obj
,
(
torch
.
nn
.
MaxPool1d
,
torch
.
nn
.
MaxPool2d
,
torch
.
nn
.
MaxPool3d
)
):
def
_handle_int_or_tuple
(
parameter
):
if
isinstance
(
parameter
,
tuple
):
return
list
(
parameter
)
elif
isinstance
(
parameter
,
int
):
return
[
parameter
,
parameter
]
pooling_params
=
{}
pooling_params
[
"filter_dim"
]
=
_handle_int_or_tuple
(
module_obj
.
kernel_size
)
pooling_params
[
"stride"
]
=
_handle_int_or_tuple
(
module_obj
.
stride
)
pooling_params
[
"padding"
]
=
_handle_int_or_tuple
(
module_obj
.
padding
)
pooling_params
[
"dilation"
]
=
_handle_int_or_tuple
(
module_obj
.
dilation
)
param_info
=
pooling_params
elif
isinstance
(
module_obj
,
(
torch
.
nn
.
AvgPool1d
,
torch
.
nn
.
AvgPool2d
,
torch
.
nn
.
AvgPool3d
)
):
pooling_params
=
{}
pooling_params
[
"filter_dim"
]
=
[
module_obj
.
kernel_size
,
module_obj
.
kernel_size
,
]
pooling_params
[
"stride"
]
=
[
module_obj
.
stride
,
module_obj
.
stride
]
pooling_params
[
"padding"
]
=
[
module_obj
.
padding
,
module_obj
.
padding
]
pooling_params
[
"ceil_mode"
]
=
module_obj
.
ceil_mode
pooling_params
[
"count_include_pad"
]
=
module_obj
.
count_include_pad
param_info
=
pooling_params
elif
isinstance
(
module_obj
,
(
torch
.
nn
.
AdaptiveAvgPool1d
,
torch
.
nn
.
AdaptiveAvgPool2d
,
torch
.
nn
.
AdaptiveAvgPool3d
,
),
):
pooling_params
=
{}
pooling_params
[
"output_size"
]
=
[
module_obj
.
output_size
,
module_obj
.
output_size
,
]
param_info
=
pooling_params
elif
isinstance
(
module_obj
,
torch
.
nn
.
Linear
):
param_info
[
"in_features"
]
=
module_obj
.
in_features
param_info
[
"out_features"
]
=
module_obj
.
out_features
elif
isinstance
(
module_obj
,
(
torch
.
nn
.
BatchNorm1d
,
torch
.
nn
.
BatchNorm2d
,
torch
.
nn
.
BatchNorm3d
),
):
param_info
[
"num_features"
]
=
module_obj
.
num_features
param_info
[
"epsilon"
]
=
module_obj
.
eps
param_info
[
"momentum"
]
=
module_obj
.
momentum
elif
isinstance
(
module_obj
,
torch
.
nn
.
ReLU
):
param_info
[
"in_place"
]
=
module_obj
.
inplace
elif
isinstance
(
module_obj
,
torch
.
nn
.
Dropout
):
param_info
[
"p"
]
=
module_obj
.
p
param_info
[
"in_place"
]
=
module_obj
.
inplace
elif
isinstance
(
module_obj
,
torch
.
nn
.
Embedding
):
param_info
[
"num_embeddings"
]
=
module_obj
.
num_embeddings
param_info
[
"embedding_dim"
]
=
module_obj
.
embedding_dim
elif
isinstance
(
module_obj
,
(
torch
.
nn
.
Upsample
,
torch
.
nn
.
UpsamplingNearest2d
,
torch
.
nn
.
UpsamplingBilinear2d
,
),
):
param_info
[
"scale_factor"
]
=
module_obj
.
scale_factor
return
param_info
def
construct_marker_dict_and_push
(
module_name
,
module_obj
,
in_tensor
,
kwargs
=
None
,
out_tensor
=
None
):
marker_dict
=
{}
marker_dict
[
"Module"
]
=
module_name
## Get trainable parameters like weights and bias
module_params
=
module_obj
.
named_parameters
(
recurse
=
False
)
for
idx
,
(
param_name
,
param_obj
)
in
enumerate
(
module_params
):
if
idx
==
0
:
marker_dict
[
"TrainableParams"
]
=
{}
marker_dict
[
"TrainableParams"
][
param_name
]
=
list
(
param_obj
.
size
())
in_tensor_list
=
print_tensor
(
in_tensor
,
"Input"
)
if
in_tensor_list
:
marker_dict
[
"Inputs"
]
=
in_tensor_list
out_tensor_list
=
print_tensor
(
out_tensor
,
"Output"
)
if
out_tensor_list
:
marker_dict
[
"Outputs"
]
=
out_tensor_list
## Get Kwargs like input_ids and positions for the top module
if
kwargs
:
for
key
,
value
in
kwargs
.
items
():
if
isinstance
(
value
,
(
torch
.
Tensor
,
list
,
tuple
)):
tensor_list
=
print_tensor
(
value
,
key
)
if
tensor_list
:
marker_dict
[
key
]
=
tensor_list
param_info
=
process_layer_params
(
module_obj
)
if
param_info
:
marker_dict
[
"StaticParams"
]
=
param_info
nvtx
.
range_push
(
"{}"
.
format
(
marker_dict
))
class
ResultHolder
:
"""Holder for storing results from within a context manager."""
result
=
None
@
contextmanager
def
layerwise_nvtx_marker_context
(
module_name
,
module_obj
,
in_tensor
=
None
,
kwargs
=
None
):
"""Context manager for NVTX markers that automatically pushes on enter
and pops on exit.
Example:
with nvtx_marker_context("Module:MyModule", module, in_tensor=args,
kwargs=kwargs) as ctx:
ctx.result = module(*args, **kwargs)
return ctx.result
"""
holder
=
ResultHolder
()
# Push input marker
construct_marker_dict_and_push
(
module_name
,
module_obj
,
in_tensor
=
in_tensor
,
kwargs
=
kwargs
,
)
try
:
yield
holder
finally
:
# Pop input marker
nvtx
.
range_pop
()
# Push and pop output marker
output_name
=
module_name
.
replace
(
"(input)"
,
"(output)"
)
construct_marker_dict_and_push
(
output_name
,
module_obj
,
in_tensor
=
None
,
kwargs
=
None
,
out_tensor
=
holder
.
result
,
)
nvtx
.
range_pop
()
class
PytHooks
:
"""This module contains all the code needed to enable forward hooks
in a pytorch network.
To register the hooks for a given network, the user needs to instantiate
a PytHook object. Then call the register_hooks method.
Example:
my_hook = PytHook()
my_hook.register_hooks(my_network_model)
"""
def
__init__
(
self
):
"""Initialize module variables."""
super
().
__init__
()
self
.
module_to_name_map
=
{}
def
_process_layer_params
(
self
,
module_obj
):
return
process_layer_params
(
module_obj
)
def
module_fwd_hook
(
self
,
module_obj
,
in_tensor
,
out_tensor
):
"""Callback function that ends the NVTX marker.
Records the module name and tensor information.
Called after the module executes the forward method.
"""
nvtx
.
range_pop
()
module_name
=
self
.
module_to_name_map
.
get
(
module_obj
,
"unknown"
)
construct_marker_dict_and_push
(
module_name
,
module_obj
,
in_tensor
=
None
,
kwargs
=
None
,
out_tensor
=
out_tensor
)
nvtx
.
range_pop
()
return
def
module_fwd_pre_hook
(
self
,
module_obj
,
in_tensor
,
kwargs
):
"""Creates an NVTX marker with the module name in it.
This function is called before the module executes.
"""
module_name
=
self
.
module_to_name_map
.
get
(
module_obj
,
"unknown"
)
construct_marker_dict_and_push
(
module_name
,
module_obj
,
in_tensor
=
in_tensor
,
kwargs
=
kwargs
,
out_tensor
=
None
)
return
def
register_hooks
(
self
,
network_model
,
module_prefix
=
"top"
):
"""User level function that activates all the hooks.
The user needs to call this method from the network source code.
The code descends all the modules in the network and registers their
respective hooks.
"""
# Module types to skip (simple operations that don't need detailed profiling)
skip_types
=
(
torch
.
nn
.
Identity
,
torch
.
nn
.
Dropout
,
torch
.
nn
.
Dropout1d
,
torch
.
nn
.
Dropout2d
,
torch
.
nn
.
Dropout3d
,
)
for
name
,
module
in
network_model
.
named_modules
(
prefix
=
module_prefix
):
# Skip certain module types to reduce profiling overhead
if
isinstance
(
module
,
skip_types
):
continue
module
.
register_forward_pre_hook
(
self
.
module_fwd_pre_hook
,
with_kwargs
=
True
)
module
.
register_forward_hook
(
self
.
module_fwd_hook
)
if
module
not
in
self
.
module_to_name_map
:
self
.
module_to_name_map
[
module
]
=
name
else
:
raise
ValueError
(
"Module instance {} is not unique "
.
format
(
module
))
return
vllm/utils/serial_utils.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
base64
import
io
import
math
import
sys
from
dataclasses
import
dataclass
from
typing
import
Literal
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
import
numpy
as
np
import
torch
from
typing_extensions
import
assert_never
from
vllm
import
PoolingRequestOutput
if
TYPE_CHECKING
:
from
vllm
import
PoolingRequestOutput
else
:
PoolingRequestOutput
=
Any
sys_byteorder
=
sys
.
byteorder
...
...
@@ -26,6 +31,14 @@ EMBED_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2"
:
torch
.
float8_e5m2
,
}
EMBED_DTYPE_TO_N_BYTES
=
{
"float32"
:
4
,
"float16"
:
2
,
"bfloat16"
:
2
,
"fp8_e4m3"
:
1
,
"fp8_e5m2"
:
1
,
}
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW
=
{
"float32"
:
torch
.
float32
,
...
...
@@ -49,7 +62,16 @@ ENDIANNESS = ["native", "big", "little"]
EmbedDType
=
Literal
[
"float32"
,
"float16"
,
"bfloat16"
,
"fp8_e4m3"
,
"fp8_e5m2"
]
Endianness
=
Literal
[
"native"
,
"big"
,
"little"
]
EncodingFormat
=
Literal
[
"float"
,
"base64"
,
"bytes"
]
EncodingFormat
=
Literal
[
"float"
,
"base64"
,
"bytes"
,
"bytes_only"
]
def
tensor2base64
(
x
:
torch
.
Tensor
)
->
str
:
with
io
.
BytesIO
()
as
buf
:
torch
.
save
(
x
,
buf
)
buf
.
seek
(
0
)
binary_data
=
buf
.
read
()
return
base64
.
b64encode
(
binary_data
).
decode
(
"utf-8"
)
def
tensor2binary
(
...
...
@@ -104,7 +126,7 @@ def encode_pooling_output(
elif
encoding_format
==
"base64"
:
embedding_bytes
=
tensor2binary
(
output
.
outputs
.
data
,
embed_dtype
,
endianness
)
return
base64
.
b64encode
(
embedding_bytes
).
decode
(
"utf-8"
)
elif
encoding_format
==
"bytes"
:
elif
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
return
tensor2binary
(
output
.
outputs
.
data
,
embed_dtype
,
endianness
)
assert_never
(
encoding_format
)
...
...
@@ -119,6 +141,29 @@ class MetadataItem:
shape
:
tuple
[
int
,
...]
def
build_metadata_items
(
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
shape
:
tuple
[
int
,
...],
n_request
:
int
,
):
n_bytes
=
EMBED_DTYPE_TO_N_BYTES
[
embed_dtype
]
size
=
math
.
prod
(
shape
)
items
=
[
MetadataItem
(
index
=
i
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
start
=
i
*
size
*
n_bytes
,
end
=
(
i
+
1
)
*
size
*
n_bytes
,
shape
=
shape
,
)
for
i
in
range
(
n_request
)
]
return
items
def
encode_pooling_bytes
(
pooling_outputs
:
list
[
PoolingRequestOutput
],
embed_dtype
:
EmbedDType
,
...
...
vllm/utils/system_utils.py
View file @
8d75f22e
...
...
@@ -204,6 +204,10 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
def
decorate_logs
(
process_name
:
str
|
None
=
None
)
->
None
:
"""Decorate stdout/stderr with process name and PID prefix."""
# Respect VLLM_CONFIGURE_LOGGING environment variable
if
not
envs
.
VLLM_CONFIGURE_LOGGING
:
return
if
process_name
is
None
:
process_name
=
get_mp_context
().
current_process
().
name
...
...
vllm/utils/torch_utils.py
View file @
8d75f22e
...
...
@@ -28,6 +28,7 @@ else:
STR_DTYPE_TO_TORCH_DTYPE
=
{
"float32"
:
torch
.
float32
,
"half"
:
torch
.
half
,
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
,
"float"
:
torch
.
float
,
"fp8"
:
torch
.
uint8
,
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
8d75f22e
...
...
@@ -8,7 +8,6 @@ from typing import ClassVar
import
numpy
as
np
import
torch
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
...
...
@@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
attention_config
=
vllm_config
.
attention_config
self
.
num_heads_q
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
...
...
@@ -304,7 +304,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self
.
max_num_splits
=
envs
.
VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
self
.
max_num_splits
=
(
self
.
attention_config
.
flash_attn_max_num_splits_for_cuda_graph
)
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
...
...
@@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl):
"heads in the layer"
)
def
supports_quant_query_input
(
self
)
->
bool
:
return
True
self
.
supports_quant_query_input
=
True
def
forward
(
self
,
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
8d75f22e
...
...
@@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
)
from
vllm.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
,
get_current_vllm_config
from
vllm.config.cache
import
CacheDType
from
vllm.distributed.parallel_state
import
get_dcp_group
from
vllm.logger
import
init_logger
...
...
@@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.flashinfer
import
(
can_use_trtllm_attention
,
flashinfer_disable_q_quantization
,
use_trtllm_attention
,
)
from
vllm.utils.math_utils
import
cdiv
...
...
@@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend):
supports_trtllm_attention
,
)
# Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
# Respect explicit disable flag (e.g.,
# --attention-config.use_trtllm_attention=0)
if
force_use_trtllm_attention
()
is
False
:
return
False
...
...
@@ -482,9 +482,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
dcp_rank
=
0
self
.
dcp_kv_cache_interleave_size
=
1
self
.
num_qo_heads
=
(
self
.
model_config
.
get_num_attention_heads
(
self
.
vllm_config
.
parallel_config
)
*
self
.
dcp_world_size
self
.
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
vllm_config
.
parallel_config
)
self
.
num_kv_heads
=
self
.
kv_cache_spec
.
num_kv_heads
...
...
@@ -501,11 +500,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
kv_cache_dtype
=
self
.
kv_cache_spec
.
dtype
# Use model dtype as q dtype when TRTLLM attn is not supported, or
#
VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
is set to 1. Otherwise,
try to
# use fp8 q if kv cache is fp8, and will fall back to model dtype
#
--attention-config.disable_flashinfer_q_quantization
is set to 1. Otherwise,
#
try to
use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm
=
can_use_trtllm_attention
(
self
.
num_qo_heads
,
self
.
num_kv_heads
)
if
can_use_trtllm
and
not
flashinfer_disable_q_quantization
():
if
(
can_use_trtllm
and
not
vllm_config
.
attention_config
.
disable_flashinfer_q_quantization
):
self
.
q_data_type
=
self
.
kv_cache_dtype
else
:
self
.
q_data_type
=
self
.
model_config
.
dtype
...
...
@@ -1036,6 +1038,11 @@ class FlashInferImpl(AttentionImpl):
self
.
sinks
=
sinks
self
.
support_trtllm_attn
=
can_use_trtllm_attention
(
num_heads
,
num_kv_heads
)
vllm_config
=
get_current_vllm_config
()
self
.
supports_quant_query_input
=
(
self
.
support_trtllm_attn
and
not
vllm_config
.
attention_config
.
disable_flashinfer_q_quantization
)
self
.
bmm1_scale
:
float
|
None
=
None
self
.
bmm2_scale
:
float
|
None
=
None
self
.
o_sf_scale
:
float
|
None
=
None
...
...
@@ -1047,12 +1054,6 @@ class FlashInferImpl(AttentionImpl):
and
quant_key
in
(
kFp8StaticTensorSym
,
kNvfp4Quant
)
)
def
supports_quant_query_input
(
self
)
->
bool
:
if
flashinfer_disable_q_quantization
():
return
False
return
self
.
support_trtllm_attn
# FlashInfer requires attention sinks to be float32
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
if
self
.
sinks
is
not
None
and
self
.
sinks
.
dtype
!=
torch
.
float32
:
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
8d75f22e
...
...
@@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import (
and_masks
,
create_block_mask
,
flex_attention
,
or_masks
,
)
from
vllm.attention.backends.abstract
import
(
...
...
@@ -31,6 +32,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
vllm.v1.attention.backends.utils
import
(
...
...
@@ -41,6 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger
=
init_logger
(
__name__
)
torch
.
_dynamo
.
config
.
recompile_limit
=
16
create_block_mask_compiled
=
torch
.
compile
(
create_block_mask
,
fullgraph
=
True
,
mode
=
"reduce-overhead"
)
...
...
@@ -90,6 +93,11 @@ class FlexAttentionBackend(AttentionBackend):
"""FlexAttention supports both decoder and encoder-only attention."""
return
attn_type
in
(
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_ONLY
)
@
classmethod
def
supports_mm_prefix
(
cls
)
->
bool
:
"""FlexAttention supports full attention for image tokens."""
return
True
@
staticmethod
def
get_impl_cls
()
->
type
[
"FlexAttentionImpl"
]:
return
FlexAttentionImpl
...
...
@@ -315,6 +323,7 @@ class FlexAttentionMetadata:
kv_block_size
:
int
=
16
transformed_score_mod
:
_score_mod_signature
|
None
=
None
sliding_window
:
int
|
None
=
None
mm_prefix_range
:
dict
[
int
,
list
[
tuple
[
int
,
int
]]]
|
None
=
None
@
cached_property
def
logical_block_ids
(
self
):
...
...
@@ -442,6 +451,45 @@ class FlexAttentionMetadata:
return
final_mask_mod
if
self
.
causal
else
sliding_window_mask_mod
def
get_prefix_lm_mask_mod
(
self
)
->
_mask_mod_signature
:
"""Creates the prefix LM mask_mod function for FlexAttention."""
assert
self
.
doc_ids
is
not
None
request_lookup
=
self
.
doc_ids
def
prefix_lm_mask_mod
(
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
cu_q_idx
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
kv_idx
:
torch
.
Tensor
,
):
mask
=
torch
.
zeros_like
(
q_idx
,
dtype
=
torch
.
bool
)
for
req
,
doc_range_lst
in
(
self
.
mm_prefix_range
or
{}).
items
():
req_mask
=
request_lookup
[
cu_q_idx
]
==
req
for
start
,
end
in
doc_range_lst
:
doc_mask_q
=
(
q_idx
>=
start
)
&
(
q_idx
<=
end
)
doc_mask_kv
=
(
kv_idx
>=
start
)
&
(
kv_idx
<=
end
)
mask
=
mask
|
(
req_mask
&
doc_mask_q
&
doc_mask_kv
)
return
mask
def
final_mask_mod
(
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
physical_kv_idx
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
(
is_valid
,
logical_q_idx
,
logical_kv_idx
)
=
(
self
.
_convert_physical_to_logical
(
self
.
doc_ids
,
q_idx
,
physical_kv_idx
)
)
return
torch
.
where
(
is_valid
,
prefix_lm_mask_mod
(
b
,
h
,
q_idx
,
logical_q_idx
,
logical_kv_idx
),
False
,
)
return
final_mask_mod
def
get_mask_mod
(
self
):
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
...
...
@@ -455,6 +503,10 @@ class FlexAttentionMetadata:
# Add sliding window mask for sliding window attention
sliding_window_mask_mod
=
self
.
get_sliding_window_mask_mod
()
mask_mod
=
and_masks
(
mask_mod
,
sliding_window_mask_mod
)
if
self
.
mm_prefix_range
:
# Add prefix LM mask for vision-language prefix LM attention
prefix_lm_mask_mod
=
self
.
get_prefix_lm_mask_mod
()
mask_mod
=
or_masks
(
mask_mod
,
prefix_lm_mask_mod
)
return
mask_mod
def
get_transformed_score_mod
(
self
)
->
_score_mod_signature
|
None
:
...
...
@@ -708,6 +760,7 @@ class FlexAttentionImpl(AttentionImpl):
sliding_window
:
int
|
None
alibi_slopes
:
torch
.
Tensor
|
None
logits_soft_cap
:
float
|
None
mm_prefix_range
:
dict
[
int
,
list
[
tuple
[
int
,
int
]]]
|
None
=
None
def
__init__
(
self
,
...
...
@@ -809,11 +862,21 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
needs_rebuild_block_mask
=
False
if
attn_metadata
.
sliding_window
!=
self
.
sliding_window
:
attn_metadata
.
sliding_window
=
self
.
sliding_window
if
attn_metadata
.
direct_build
:
# update mask mod in attention metadata
attn_metadata
.
mask_mod
=
attn_metadata
.
get_mask_mod
()
needs_rebuild_block_mask
=
True
if
self
.
mm_prefix_range
!=
getattr
(
attn_metadata
,
"mm_prefix_range"
,
None
):
self
.
mm_prefix_range
=
attn_metadata
.
mm_prefix_range
attn_metadata
.
mask_mod
=
attn_metadata
.
get_mask_mod
()
needs_rebuild_block_mask
=
True
if
needs_rebuild_block_mask
:
if
attn_metadata
.
direct_build
and
attn_metadata
.
causal
:
attn_metadata
.
block_mask
=
attn_metadata
.
_build_block_mask_direct
()
else
:
attn_metadata
.
block_mask
=
attn_metadata
.
build_block_mask
()
...
...
@@ -927,7 +990,18 @@ def get_kernel_options(
if
torch
.
cuda
.
is_available
():
device_props
=
torch
.
cuda
.
get_device_properties
()
max_shared_memory
=
device_props
.
shared_memory_per_block_optin
# ROCm doesn't expose shared_memory_per_block_optin attribute
# AMD GPUs typically have 64KB LDS (Local Data Share) per workgroup
if
hasattr
(
device_props
,
"shared_memory_per_block_optin"
):
max_shared_memory
=
device_props
.
shared_memory_per_block_optin
elif
current_platform
.
is_rocm
():
# ROCm fallback: use 64KB
max_shared_memory
=
65536
else
:
raise
RuntimeError
(
"Unable to determine shared memory size on this hardware."
)
if
max_shared_memory
<
144
*
1024
:
block_m_candidate
=
ensure_divisible
(
max
(
1
,
block_m_candidate
//
2
),
block_m
...
...
Prev
1
…
29
30
31
32
33
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment