Commit b3230e1a authored by Yongye Zhu's avatar Yongye Zhu Committed by simon-mo
Browse files
parent 03df0fb5
...@@ -70,6 +70,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -70,6 +70,7 @@ _TEXT_GENERATION_MODELS = {
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"), "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
"Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"), "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
"Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
......
...@@ -93,11 +93,14 @@ class CpuPlatform(Platform): ...@@ -93,11 +93,14 @@ class CpuPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool, block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str: has_sink: bool, use_sparse: bool) -> str:
if selected_backend and selected_backend != _Backend.TORCH_SDPA: if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if use_mla:
raise NotImplementedError("MLA is not supported on CPU.") raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on CPU.")
logger.info("Using Torch SDPA backend.") logger.info("Using Torch SDPA backend.")
if not use_v1: if not use_v1:
raise ValueError("CPU backend only supports V1.") raise ValueError("CPU backend only supports V1.")
......
...@@ -129,6 +129,8 @@ class CudaPlatformBase(Platform): ...@@ -129,6 +129,8 @@ class CudaPlatformBase(Platform):
# TODO(lucas): handle this more gracefully # TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing # Note: model_config may be None during testing
if model_config is not None and model_config.use_mla: if model_config is not None and model_config.use_mla:
use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk")
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs, # then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the # else we default to CutlassMLA. For each case, we force the
...@@ -175,6 +177,12 @@ class CudaPlatformBase(Platform): ...@@ -175,6 +177,12 @@ class CudaPlatformBase(Platform):
"Forcing kv cache block size to 64 for FlashInferMLA " "Forcing kv cache block size to 64 for FlashInferMLA "
"backend.") "backend.")
# TODO(Chen): remove this hacky code
if use_sparse and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse "
"backend.")
# lazy import to avoid circular import # lazy import to avoid circular import
from vllm.config import CUDAGraphMode from vllm.config import CUDAGraphMode
...@@ -231,7 +239,7 @@ class CudaPlatformBase(Platform): ...@@ -231,7 +239,7 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla, kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str: has_sink, use_sparse) -> str:
if use_mla: if use_mla:
if not use_v1: if not use_v1:
raise RuntimeError( raise RuntimeError(
...@@ -241,6 +249,11 @@ class CudaPlatformBase(Platform): ...@@ -241,6 +249,11 @@ class CudaPlatformBase(Platform):
from vllm.attention.ops.flashmla import is_flashmla_supported from vllm.attention.ops.flashmla import is_flashmla_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla from vllm.attention.utils.fa_utils import flash_attn_supports_mla
if use_sparse:
logger.info_once("Using Sparse MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla.flashmla_sparse."
"FlashMLASparseBackend")
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None and cls.is_device_capability(100) selected_backend is None and cls.is_device_capability(100)
and block_size == 128) and block_size == 128)
......
...@@ -194,7 +194,7 @@ class Platform: ...@@ -194,7 +194,7 @@ class Platform:
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool, block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str: has_sink: bool, use_sparse: bool) -> str:
"""Get the attention backend class of a device.""" """Get the attention backend class of a device."""
return "" return ""
......
...@@ -195,7 +195,10 @@ class RocmPlatform(Platform): ...@@ -195,7 +195,10 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla, kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str: has_sink, use_sparse) -> str:
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on ROCm.")
if use_mla: if use_mla:
if not use_v1: if not use_v1:
raise RuntimeError( raise RuntimeError(
......
...@@ -49,7 +49,10 @@ class TpuPlatform(Platform): ...@@ -49,7 +49,10 @@ class TpuPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool, block_size: int, use_v1: bool, use_mla: bool,
has_sink) -> str: has_sink, use_sparse) -> str:
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on TPU.")
if selected_backend != _Backend.PALLAS: if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Cannot use %s backend on TPU.", selected_backend)
......
...@@ -36,7 +36,10 @@ class XPUPlatform(Platform): ...@@ -36,7 +36,10 @@ class XPUPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str], dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool, block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str: has_sink: bool, use_sparse) -> str:
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on XPU.")
use_v1 = envs.VLLM_USE_V1 use_v1 = envs.VLLM_USE_V1
if not use_v1: if not use_v1:
raise ValueError("XPU backend only supports V1.") raise ValueError("XPU backend only supports V1.")
......
...@@ -66,6 +66,8 @@ class LazyConfigDict(dict): ...@@ -66,6 +66,8 @@ class LazyConfigDict(dict):
_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
chatglm="ChatGLMConfig", chatglm="ChatGLMConfig",
deepseek_vl_v2="DeepseekVLV2Config", deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v3="DeepseekV3Config",
deepseek_v32="DeepseekV3Config",
kimi_vl="KimiVLConfig", kimi_vl="KimiVLConfig",
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
......
...@@ -8,6 +8,7 @@ Model configs may be defined in this directory for the following reasons: ...@@ -8,6 +8,7 @@ Model configs may be defined in this directory for the following reasons:
""" """
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig
from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig
...@@ -37,6 +38,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig ...@@ -37,6 +38,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [ __all__ = [
"ChatGLMConfig", "ChatGLMConfig",
"DeepseekVLV2Config", "DeepseekVLV2Config",
"DeepseekV3Config",
"DotsOCRConfig", "DotsOCRConfig",
"EAGLEConfig", "EAGLEConfig",
"RWConfig", "RWConfig",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class DeepseekV3Config(PretrainedConfig):
    """Configuration container for models whose `model_type` is "deepseek_v3".

    Every named constructor argument is stored unchanged as an instance
    attribute of the same name. The only derived value is
    `num_key_value_heads`, which falls back to `num_attention_heads` when
    it is passed as `None`.

    `pad_token_id`, `bos_token_id`, `eos_token_id`, `tie_word_embeddings`
    and any additional `**kwargs` are not stored here directly; they are
    forwarded to `PretrainedConfig.__init__`.
    """

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
        moe_intermediate_size=2048,
        num_hidden_layers=61,
        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
        n_shared_experts=1,
        n_routed_experts=256,
        ep_size=1,
        routed_scaling_factor=2.5,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method='noaux_tc',
        n_group=8,
        topk_group=4,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
        scoring_func='sigmoid',
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        """Store the given hyper-parameters on the config instance.

        All attributes are assigned before `super().__init__` is called;
        the base initializer receives only the token ids,
        `tie_word_embeddings` and the remaining `**kwargs`.
        """
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        # for backward compatibility
        # (an explicit None means "use as many KV heads as attention heads")
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Token ids, weight tying and any unrecognised options are handled
        # by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
...@@ -130,6 +130,7 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -130,6 +130,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2": torch.uint8, "fp8_e5m2": torch.uint8,
"int8": torch.int8, "int8": torch.int8,
"fp8_inc": torch.float8_e4m3fn, "fp8_inc": torch.float8_e4m3fn,
"fp8_ds_mla": torch.uint8,
} }
TORCH_DTYPE_TO_NUMPY_DTYPE = { TORCH_DTYPE_TO_NUMPY_DTYPE = {
...@@ -3433,6 +3434,12 @@ def has_triton_kernels() -> bool: ...@@ -3433,6 +3434,12 @@ def has_triton_kernels() -> bool:
return _has_module("triton_kernels") return _has_module("triton_kernels")
def has_tilelang() -> bool:
    """Report whether the optional `tilelang` package can be found.

    The lookup is delegated to the module's `_has_module` helper.
    """
    tilelang_available = _has_module("tilelang")
    return tilelang_available
def set_process_title(name: str, def set_process_title(name: str,
suffix: str = "", suffix: str = "",
prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None: prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None:
......
...@@ -70,17 +70,25 @@ def _missing(*_: Any, **__: Any) -> NoReturn: ...@@ -70,17 +70,25 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
_fp8_gemm_nt_impl: Callable[..., Any] | None = None _fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None _grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None _grouped_masked_impl: Callable[..., Any] | None = None
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
def _lazy_init() -> None: def _lazy_init() -> None:
"""Import deep_gemm and resolve symbols on first use.""" """Import deep_gemm and resolve symbols on first use."""
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl,\ global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
_get_mn_major_tma_aligned_tensor_impl global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
global _get_paged_mqa_logits_metadata_impl
global _get_mn_major_tma_aligned_tensor_impl
# fast path # fast path
if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
or _grouped_masked_impl is not None): or _grouped_masked_impl is not None
or _fp8_mqa_logits_impl is not None
or _fp8_paged_mqa_logits_impl is not None
or _get_paged_mqa_logits_metadata_impl is not None):
return return
if not has_deep_gemm(): if not has_deep_gemm():
...@@ -97,10 +105,20 @@ def _lazy_init() -> None: ...@@ -97,10 +105,20 @@ def _lazy_init() -> None:
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None) _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None) _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
_grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None) _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
_fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
_fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
_get_paged_mqa_logits_metadata_impl = getattr(
_dg, "get_paged_mqa_logits_metadata", None)
_get_mn_major_tma_aligned_tensor_impl = getattr( _get_mn_major_tma_aligned_tensor_impl = getattr(
_dg, "get_mn_major_tma_aligned_tensor", None) _dg, "get_mn_major_tma_aligned_tensor", None)
def get_num_sms() -> int:
    """Return the SM count reported by `deep_gemm.get_num_sms()` as an int.

    `_lazy_init()` is invoked first so the DeepGEMM symbols are resolved
    before the module is queried.
    """
    _lazy_init()
    deep_gemm_module = importlib.import_module("deep_gemm")
    sm_count = deep_gemm_module.get_num_sms()
    return int(sm_count)
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
"""Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor""" """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
_lazy_init() _lazy_init()
...@@ -135,6 +153,100 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): ...@@ -135,6 +153,100 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """FP8 MQA logits for one sequence, without a paged KV cache.

    Thin wrapper that forwards to DeepGEMM's `fp8_mqa_logits` once the
    library symbols have been lazily resolved; if the symbol is not
    available the call is routed to `_missing()`.

    Args:
        q: Query tensor, shape [M, H, D]. The caller is expected to have
            cast it to `torch.float8_e4m3fn`.
        kv: Pair `(k_fp8, k_scales)`. `k_fp8` is [N, D] in
            `torch.float8_e4m3fn`; `k_scales` is [N] (or [N, 1]) in
            `torch.float32`.
        weights: Tensor of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Inclusive start index of the valid K range for each
            query position; shape [M], dtype int32.
        cu_seqlen_ke: Exclusive end index of the valid K range for each
            query position; shape [M], dtype int32.

    Returns:
        Logits of shape [M, N], dtype `torch.float32`.
    """
    _lazy_init()
    impl = _fp8_mqa_logits_impl
    if impl is not None:
        return impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
    return _missing()
def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int,
                                  num_sms: int) -> torch.Tensor:
    """Produce the scheduling metadata used by paged MQA logits.

    Thin wrapper over DeepGEMM's `get_paged_mqa_logits_metadata`; falls
    through to `_missing()` when that symbol could not be resolved.

    Args:
        context_lens: Effective context length of each batch element;
            shape [B], dtype int32.
        block_size: Number of tokens per KV-cache block (e.g., 64).
        num_sms: Number of SMs available. 132 for Hopper.

    Returns:
        A backend-specific tensor that `fp8_paged_mqa_logits` consumes to
        spread work across SMs.
    """
    _lazy_init()
    impl = _get_paged_mqa_logits_metadata_impl
    if impl is not None:
        return impl(context_lens, block_size, num_sms)
    return _missing()
def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """FP8 MQA logits computed against a paged KV cache.

    Thin wrapper over DeepGEMM's `fp8_paged_mqa_logits`, always invoked
    with `clean_logits=True`; falls through to `_missing()` when the
    symbol could not be resolved.

    Args:
        q_fp8: Query tensor, shape [B, next_n, H, D]. The caller is
            expected to have cast it to `torch.float8_e4m3fn`.
        kv_cache_fp8: Paged KV cache in packed FP8+scale layout, shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The
            trailing 4 bytes of each (block, pos) entry hold the `float`
            dequantization scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Effective context length of each batch element;
            shape [B], dtype int32.
        block_tables: Logical-to-physical block mapping for the paged
            cache; shape [B, max_blocks], dtype int32.
        schedule_metadata: Output of `get_paged_mqa_logits_metadata`,
            used to distribute work across SMs.
        max_model_len: Maximum sequence length; sizes the logits output.

    Returns:
        Logits of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    _lazy_init()
    impl = _fp8_paged_mqa_logits_impl
    if impl is not None:
        return impl(q_fp8,
                    kv_cache_fp8,
                    weights,
                    context_lens,
                    block_tables,
                    schedule_metadata,
                    max_model_len,
                    clean_logits=True)
    return _missing()
def _ceil_to_ue8m0(x: torch.Tensor): def _ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
...@@ -195,9 +307,13 @@ __all__ = [ ...@@ -195,9 +307,13 @@ __all__ = [
"fp8_gemm_nt", "fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous", "m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked", "fp8_m_grouped_gemm_nt_masked",
"fp8_mqa_logits",
"fp8_paged_mqa_logits",
"get_paged_mqa_logits_metadata",
"per_block_cast_to_fp8", "per_block_cast_to_fp8",
"is_deep_gemm_e8m0_used", "is_deep_gemm_e8m0_used",
"is_deep_gemm_supported", "is_deep_gemm_supported",
"get_num_sms",
"should_use_deepgemm_for_fp8_linear", "should_use_deepgemm_for_fp8_linear",
"get_col_major_tma_aligned_tensor", "get_col_major_tma_aligned_tensor",
] ]
\ No newline at end of file
...@@ -74,6 +74,7 @@ class TorchSDPABackend(AttentionBackend): ...@@ -74,6 +74,7 @@ class TorchSDPABackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return _get_paged_attn_impl().get_kv_cache_shape( return _get_paged_attn_impl().get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size) num_blocks, block_size, num_kv_heads, head_size)
......
...@@ -80,6 +80,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -80,6 +80,7 @@ class FlashAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
......
...@@ -187,6 +187,7 @@ class FlashInferBackend(AttentionBackend): ...@@ -187,6 +187,7 @@ class FlashInferBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size) return (num_blocks, 2, block_size, num_kv_heads, head_size)
......
...@@ -88,6 +88,7 @@ class FlexAttentionBackend(AttentionBackend): ...@@ -88,6 +88,7 @@ class FlexAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads, head_size) return (2, num_blocks, block_size, num_kv_heads, head_size)
......
...@@ -286,6 +286,7 @@ class MLACommonBackend(AttentionBackend): ...@@ -286,6 +286,7 @@ class MLACommonBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA num_kv_heads: int, # assumed to be 1 for MLA
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return (num_blocks, block_size, head_size) return (num_blocks, block_size, head_size)
...@@ -407,6 +408,7 @@ class MLACommonMetadata(Generic[D]): ...@@ -407,6 +408,7 @@ class MLACommonMetadata(Generic[D]):
M = TypeVar("M", bound=MLACommonMetadata) M = TypeVar("M", bound=MLACommonMetadata)
A = TypeVar("A")
def use_flashinfer_prefill() -> bool: def use_flashinfer_prefill() -> bool:
...@@ -930,7 +932,9 @@ def reorg_kvcache( ...@@ -930,7 +932,9 @@ def reorg_kvcache(
return reorganized_kv_c_normed, reorganized_k_pe return reorganized_kv_c_normed, reorganized_k_pe
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
# and MLACommonImpl -> MLACommonDenseImpl or something like that
class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
""" """
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
...@@ -956,6 +960,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -956,6 +960,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
qk_head_dim: int, qk_head_dim: int,
v_head_dim: int, v_head_dim: int,
kv_b_proj: ColumnParallelLinear, kv_b_proj: ColumnParallelLinear,
indexer=None,
q_pad_num_heads: Optional[int] = None, q_pad_num_heads: Optional[int] = None,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
...@@ -974,8 +979,140 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -974,8 +979,140 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.qk_head_dim = qk_head_dim self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim self.v_head_dim = v_head_dim
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads self.q_pad_num_heads = q_pad_num_heads
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Derive the per-head up-projection matrices from `self.kv_b_proj`.

    The (possibly dequantized) `kv_b_proj` weight is reshaped to
    (kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim) and split
    along the last dimension into `W_UK` and `W_UV`.

    * When `is_rocm_aiter_fp8bmm_enabled()` is true, both halves are
      quantized with `dynamic_per_batched_tensor_quant` and stored as
      `self.W_K`/`self.W_K_scale` and `self.W_V`/`self.W_V_scale`; the
      FP8 BMM kernel is then invoked once per batch size in 1..1024 on
      uninitialized inputs so that it is compiled ahead of time.
    * Otherwise the 16-bit copies are stored as `self.W_UV` and
      `self.W_UK_T`.

    Args:
        act_dtype: dtype of the identity matrix used to recover dense
            weights from a quantized layer.
    """

    def get_layer_weight(layer):
        # Return the first weight-like attribute found on `layer`; only
        # used below to discover which device the layer lives on.
        WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
        for attr in WEIGHT_NAMES:
            if hasattr(layer, attr):
                return getattr(layer, attr)
        raise AttributeError(
            f"Layer '{layer}' has no recognized weight attribute:"
            f" {WEIGHT_NAMES}.")

    def get_and_maybe_dequant_weights(layer: LinearBase):
        # For a quantized layer, recover the dense weight by pushing an
        # identity matrix through the layer's own quantized matmul.
        if not isinstance(layer.quant_method, UnquantizedLinearMethod):
            # NOTE: This should only be used offline, since it's O(N^3)
            eye = torch.eye(layer.input_size_per_partition,
                            dtype=act_dtype,
                            device=get_layer_weight(layer).device)
            dequant_weights = layer.quant_method.apply(layer,
                                                       eye,
                                                       bias=None)
            del eye
            # standardize to (output, input)
            return dequant_weights.T
        return layer.weight

    # we currently do not have quantized bmm's which are needed for
    # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
    # the bmm's in 16-bit, the extra memory overhead of this is fairly low
    kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
    assert kv_b_proj_weight.shape == (
        self.kv_lora_rank,
        self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}")
    # Expose the head dimension so the K and V halves can be split per head.
    kv_b_proj_weight = kv_b_proj_weight.view(
        self.kv_lora_rank,
        self.num_heads,
        self.qk_nope_head_dim + self.v_head_dim,
    )

    W_UK, W_UV = kv_b_proj_weight.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    if is_rocm_aiter_fp8bmm_enabled():
        W_K = W_UK.transpose(0, 1)  # 16 512 128
        W_V = W_UV.permute(1, 2, 0)  # 16 128 512
        self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
            W_K, dtype=current_platform.fp8_dtype())
        self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
            W_V, dtype=current_platform.fp8_dtype())

        # The kernel operates on non-padded inputs. Hence, pre-compiling
        # triton kernel to avoid runtime compilation for unseen batch sizes
        # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
        # On DS-R1, this step adds roughly 50s to the model loading time.
        max_batch_size = 1024  # [ToDo] Find the optimal upper limit
        pre_compilation_list = list(range(1, max_batch_size + 1))
        # Only the global first rank wraps the loop in a progress bar.
        if is_global_first_rank():
            pre_compilation_list = tqdm(
                pre_compilation_list,
                desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
                total=max_batch_size,
            )

        for m in pre_compilation_list:
            # Inputs are uninitialized: the calls exist only to trigger
            # compilation, and their results are discarded.
            x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
                            dtype=torch.bfloat16,
                            device=self.W_K.device)
            aiter_triton_fp8_bmm(x,
                                 self.W_K,
                                 self.W_K_scale,
                                 group_size=128,
                                 transpose_bm=True)

            x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
                            dtype=torch.bfloat16,
                            device=self.W_V.device)
            aiter_triton_fp8_bmm(x,
                                 self.W_V,
                                 self.W_V_scale,
                                 group_size=128,
                                 transpose_bm=True)
    else:
        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
    """Apply the per-head V up-projection to `x`, writing into `out`.

    Nothing is returned; the result is delivered through `out`. In the
    shape comments below, N is `self.num_heads`, L is
    `self.kv_lora_rank` and V is `self.v_head_dim`.

    * When `is_rocm_aiter_fp8bmm_enabled()` is true, the product is
      computed by `aiter_triton_fp8_bmm` with `self.W_V` and
      `self.W_V_scale`, then copied into `out`.
    * Otherwise `torch.bmm` with `self.W_UV` writes into a transposed
      view of `out`, and the data is then rearranged back into the
      (B, N * V) layout.
    """
    # Convert from (B, N, L) to (N, B, L)
    x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
    if is_rocm_aiter_fp8bmm_enabled():
        # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
        x = aiter_triton_fp8_bmm(x,
                                 self.W_V,
                                 self.W_V_scale,
                                 group_size=128,
                                 transpose_bm=True)
        # Convert from (B, N, V) to (B, N * V)
        x = x.reshape(-1, self.num_heads * self.v_head_dim)
        # Copy result
        out.copy_(x)
    else:
        # Convert from (B, N * V) to (N, B, V)
        out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)

        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"

        # Convert from (N, B, V) to (B, N * V)
        out_new = out.transpose(0, 1).reshape(
            -1, self.num_heads * self.v_head_dim)

        # Adjust output buffer shape back to the original (B, N * V)
        N, B, V = out.shape
        out.resize_((B, N * V))
        out.copy_(out_new)  # Copy result
class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
if use_flashinfer_prefill(): if use_flashinfer_prefill():
logger.debug_once("Using FlashInfer prefill for MLA") logger.debug_once("Using FlashInfer prefill for MLA")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
...@@ -1154,36 +1291,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1154,36 +1291,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
True, #Indicates actual_seq_lens are on GPU or CPU. True, #Indicates actual_seq_lens are on GPU or CPU.
) )
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
    """Apply the per-head V up-projection to `x`, writing into `out`.

    Nothing is returned; the result is delivered through `out`. In the
    shape comments below, N is `self.num_heads`, L is
    `self.kv_lora_rank` and V is `self.v_head_dim`.

    * When `is_rocm_aiter_fp8bmm_enabled()` is true, the product is
      computed by `aiter_triton_fp8_bmm` with `self.W_V` and
      `self.W_V_scale`, then copied into `out`.
    * Otherwise `torch.bmm` with `self.W_UV` writes into a transposed
      view of `out`, and the data is then rearranged back into the
      (B, N * V) layout.
    """
    # Convert from (B, N, L) to (N, B, L)
    x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
    if is_rocm_aiter_fp8bmm_enabled():
        # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
        x = aiter_triton_fp8_bmm(x,
                                 self.W_V,
                                 self.W_V_scale,
                                 group_size=128,
                                 transpose_bm=True)
        # Convert from (B, N, V) to (B, N * V)
        x = x.reshape(-1, self.num_heads * self.v_head_dim)
        # Copy result
        out.copy_(x)
    else:
        # Convert from (B, N * V) to (N, B, V)
        out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)

        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"

        # Convert from (N, B, V) to (B, N * V)
        out_new = out.transpose(0, 1).reshape(
            -1, self.num_heads * self.v_head_dim)

        # Adjust output buffer shape back to the original (B, N * V)
        N, B, V = out.shape
        out.resize_((B, N * V))
        out.copy_(out_new)  # Copy result
def process_weights_after_loading(self, act_dtype: torch.dtype): def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer): def get_layer_weight(layer):
...@@ -1455,6 +1562,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1455,6 +1562,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor, k_scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
assert self.dcp_world_size is not None assert self.dcp_world_size is not None
......
...@@ -177,6 +177,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -177,6 +177,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
attn_metadata: FlashMLAMetadata, attn_metadata: FlashMLAMetadata,
layer: AttentionLayer, layer: AttentionLayer,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
# TODO: (zyongye) decode function for mla here
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
    """Backend descriptor for the DeepseekV32 indexer.

    Exposes the metadata and builder classes used with this backend, the
    head sizes it accepts, and the layout of its cache tensor.
    """

    @staticmethod
    def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
        """Return the metadata-builder class paired with this backend."""
        return DeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the metadata class paired with this backend."""
        return DeepseekV32IndexerMetadata

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        head_sizes = [32, 64, 128]
        return head_sizes

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the identity stride order for the 3-D cache tensor."""
        return (0, 1, 2)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the cache shape `(num_blocks, block_size, head_size)`.

        Exactly one KV head is required, so the head dimension is dropped
        from the shape. `cache_dtype_str` does not affect the result.
        """
        assert num_kv_heads == 1
        cache_shape = (num_blocks, block_size, head_size)
        return cache_shape
@dataclass
class DeepseekV32IndexerPrefillMetadata:
    """Per-step indexer metadata covering only the prefill requests."""

    # Rows of the block table that belong to the prefill requests.
    block_table: torch.Tensor
    # Query start offsets of the prefill requests, rebased so that the
    # first prefill request starts at 0.
    query_start_loc: torch.Tensor
    # Maximum query length as reported by the common attention metadata.
    max_query_len: int
    # Per-token start / exclusive end of the KV span inside the
    # concatenated prefill KV (output of ``kv_spans_from_batches``).
    cu_seqlen_ks: torch.Tensor
    cu_seqlen_ke: torch.Tensor
    # Cumulative prefill sequence lengths with a leading zero.
    cu_seq_lens: torch.Tensor
    # Sum of the sequence lengths of all prefill requests.
    total_seq_lens: int
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
    """Per-step indexer metadata covering only the decode requests."""

    # Rows of the block table that belong to the decode requests.
    block_table: torch.Tensor
    # Sequence lengths of the decode requests.
    seq_lens: torch.Tensor
    # Number of query tokens of each decode request.
    decode_lens: torch.Tensor
    # True when the decode requests do not all have the same query length.
    requires_padding: bool
    # Scheduling metadata produced by ``get_paged_mqa_logits_metadata``.
    schedule_metadata: torch.Tensor
@dataclass
class DeepseekV32IndexerMetadata:
    """Top-level indexer metadata for one step.

    Carries batch-wide fields plus optional decode- and prefill-specific
    sub-metadata.
    """

    # FIXME (zyongye): hacky way to access the data for now; this should
    # live in the chunked metadata instead.
    seq_lens: torch.Tensor
    num_reqs: int
    max_query_len: int
    max_seq_len: int
    # Number of tokens excluding padding.
    num_actual_tokens: int
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    # The dimension of the attention heads.
    head_dim: int

    # New for MLA (compared to FlashAttention): bookkeeping for the
    # prefill / decode split of the batch.
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # Sub-metadata; each stays None when it is not supplied.
    decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
    prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
# TODO (zyongye) optimize this, this is now vibe coded
def kv_spans_from_batches(
        start_seq_loc: torch.Tensor,
        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute each selected token's KV span in a concatenated KV cache.

    Args:
        start_seq_loc: 1D integer tensor [B+1], cumulative counts of
            selected tokens per batch, starting at 0.
            Example: [0, 2, 4, 7] ->
            batch sizes (selected) [2, 2, 3], N=7 tokens total.
        seq_len_per_batch: 1D integer tensor [B],
            full sequence length (KV length) of each batch.
            Example: [5, 9, 4].

    Returns:
        start_tensor: 1D int32 tensor [N], start offset in the
            concatenated KV cache for each token's batch.
        end_location: 1D int32 tensor [N],
            **exclusive** end = start + token's local position.
            (So the attended KV slice is kv[start:end].)

        For the example inputs above the result is
        start = [0, 0, 5, 5, 14, 14, 14] and
        end = [4, 5, 13, 14, 16, 17, 18].
        Both outputs are int32 and live on the device of `start_seq_loc`,
        including when N == 0.

    Assumes each batch contributes its full `seq_len_per_batch[i]`
    keys to the KV cache, and the selected tokens within a batch
    are the **last** `counts[i]` positions of that sequence.
    """
    query_offsets = start_seq_loc.to(dtype=torch.long)
    kv_lens = seq_len_per_batch.to(dtype=torch.long,
                                   device=query_offsets.device)
    assert query_offsets.dim() == 1 and kv_lens.dim() == 1
    assert query_offsets.numel() == kv_lens.numel() + 1, \
        "start_seq_loc must have length B+1"

    device = kv_lens.device
    num_batches = kv_lens.numel()

    # Selected tokens per batch and in total.
    counts = query_offsets[1:] - query_offsets[:-1]  # [B]
    num_tokens = int(query_offsets[-1].item())

    if num_tokens == 0:
        # Same dtype as the non-empty path so callers always see int32.
        return (torch.empty(0, dtype=torch.int32, device=device),
                torch.empty(0, dtype=torch.int32, device=device))

    # KV start offset of every batch in the concatenated KV cache.
    kv_starts_per_batch = torch.cumsum(kv_lens, dim=0) - kv_lens  # [B]

    # Batch index of every selected token. This is the only expansion
    # needed; every other per-token quantity is gathered through it.
    batch_id = torch.repeat_interleave(
        torch.arange(num_batches, device=device), counts)  # [N]

    start_tensor = kv_starts_per_batch[batch_id]  # [N]

    # 1-based position of each token inside its batch's selected block.
    pos_within = (torch.arange(num_tokens, device=device, dtype=torch.long) -
                  query_offsets[:-1][batch_id] + 1)

    # The selected tokens are end-aligned within their sequence:
    # local_pos = L[b] - counts[b] + (1..counts[b]).
    local_pos = kv_lens[batch_id] - counts[batch_id] + pos_within  # [N]
    end_location = start_tensor + local_pos  # exclusive end

    return start_tensor.int(), end_location.int()
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
    """Return an estimated upper bound for the flattened prefill KV size.

    The estimate is ``max_model_len * max_num_seqs``. The scheduler's
    ``max_num_batched_tokens`` is currently not part of the estimate.
    """
    model_config = vllm_config.model_config
    longest_sequence = model_config.max_model_len
    scheduler_config = vllm_config.scheduler_config
    most_sequences = scheduler_config.max_num_seqs
    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
    return longest_sequence * most_sequences
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
    """Builds ``DeepseekV32IndexerMetadata`` from common attention metadata.

    The batch is split into decode requests (the first ``num_decodes``
    requests) and prefill requests (the remaining ones); separate
    sub-metadata is built for each non-empty part.
    """

    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(self, *args, **kwargs):
        """Set up size limits and persistent device buffers.

        All arguments are forwarded unchanged to the parent builder.
        """
        super().__init__(*args, **kwargs)
        scheduler_config = self.vllm_config.scheduler_config
        # NOTE(Chen): an estimated max size of flattened_kv.
        # Need to double check.
        self.max_prefill_buffer_size = get_max_prefill_buffer_size(
            self.vllm_config)
        self.num_speculative_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config else 0)
        # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2,
        # so at most one extra token per request is treated as decode.
        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

        props = torch.cuda.get_device_properties(self.device)
        self.num_sms = props.multi_processor_count

        # Persistent buffer for the per-request decode lengths.
        self.decode_lens_buffer = torch.empty(
            (scheduler_config.max_num_seqs, ),
            dtype=torch.int32,
            device=self.device)

        # Persistent buffer for the paged-MQA schedule metadata.
        # See: DeepGEMM/csrc/apis/attention.hpp
        self.scheduler_metadata_buffer = torch.empty((self.num_sms + 1, 2),
                                                     dtype=torch.int32,
                                                     device=self.device)

    def _build_prefill_metadata(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        reqs_start: int,
    ) -> DeepseekV32IndexerPrefillMetadata:
        """Build metadata for the requests from index ``reqs_start`` on."""
        query_start_loc = common_attn_metadata.query_start_loc
        # Rebase so that the first prefill request starts at offset 0.
        prefill_query_start_loc = (query_start_loc[reqs_start:] -
                                   query_start_loc[reqs_start])
        prefill_seq_lens = common_attn_metadata.seq_lens[reqs_start:]

        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            prefill_query_start_loc, prefill_seq_lens)

        # Stored as a plain int to match the declared field type. The
        # bound check below needs the concrete value anyway.
        total_seq_lens = int(prefill_seq_lens.sum().item())
        assert total_seq_lens < self.max_prefill_buffer_size

        # Keep the result on the builder's device instead of whatever the
        # current CUDA device happens to be.
        cu_seq_lens = torch.cat([
            torch.zeros(1, dtype=torch.int32, device=self.device),
            prefill_seq_lens.cumsum(dim=0),
        ]).to(device=self.device, dtype=torch.int32)

        return DeepseekV32IndexerPrefillMetadata(
            block_table=common_attn_metadata.block_table_tensor[
                reqs_start:, ...],
            query_start_loc=prefill_query_start_loc,
            max_query_len=common_attn_metadata.max_query_len,
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seq_lens=cu_seq_lens,
            total_seq_lens=total_seq_lens,
        )

    def _build_decode_metadata(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        num_decodes: int,
    ) -> DeepSeekV32IndexerDecodeMetadata:
        """Build metadata for the first ``num_decodes`` requests."""
        # Write the per-request decode lengths into the persistent buffer.
        torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
                   out=self.decode_lens_buffer[:num_decodes])
        decode_lens = self.decode_lens_buffer[:num_decodes]

        # Use CPU to avoid GPU sync; breaking async scheduling
        decode_lens_cpu = torch.diff(
            common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
        requires_padding = (decode_lens_cpu.max()
                            > decode_lens_cpu.min()).item()

        seq_lens = common_attn_metadata.seq_lens[:num_decodes]
        # Copy into the persistent buffer rather than rebinding it.
        self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
            seq_lens, self.kv_cache_spec.block_size, self.num_sms)

        return DeepSeekV32IndexerDecodeMetadata(
            block_table=common_attn_metadata.
            block_table_tensor[:num_decodes, ...],
            seq_lens=seq_lens,
            decode_lens=decode_lens,
            requires_padding=requires_padding,
            schedule_metadata=self.scheduler_metadata_buffer,
        )

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DeepseekV32IndexerMetadata:
        """Build the indexer metadata for one step.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Batch-wide attention metadata.
            fast_build: Not used by this builder.

        Returns:
            Metadata whose ``prefill`` / ``decode`` members are populated
            only when the batch contains requests of that kind.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold)

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            # Prefill requests are the ones after the decode requests.
            prefill_metadata = self._build_prefill_metadata(
                common_attn_metadata, reqs_start=num_decodes)

        decode_metadata = None
        if num_decodes > 0:
            decode_metadata = self._build_decode_metadata(
                common_attn_metadata, num_decodes)

        return DeepseekV32IndexerMetadata(
            seq_lens=common_attn_metadata.seq_lens,
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            # NOTE(review): hard-coded head dimension; confirm it matches
            # the indexer head size configured for the model.
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment