Commit d3733d54 authored by xuxz
Browse files

[PD]支持glm5 model

parent 41d4c5c1
...@@ -242,7 +242,7 @@ class DuSwiftConnectorDp(KVConnectorBase_V1): ...@@ -242,7 +242,7 @@ class DuSwiftConnectorDp(KVConnectorBase_V1):
request_id (str): request id for log request_id (str): request id for log
""" """
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()): if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) or dst_kv_cache_layer.ndim == 3:
num_pages = dst_kv_cache_layer_shape[0] num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1] page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape( dst_kv_cache_layer = dst_kv_cache_layer.reshape(
...@@ -327,7 +327,7 @@ class DuSwiftConnectorDp(KVConnectorBase_V1): ...@@ -327,7 +327,7 @@ class DuSwiftConnectorDp(KVConnectorBase_V1):
self.du_swift_engine.pool.free(addr) self.du_swift_engine.pool.free(addr)
else: else:
dst_kv_cache_layer_shape = kv_cache_layer.shape dst_kv_cache_layer_shape = kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()): if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) :
num_pages = dst_kv_cache_layer_shape[0] num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1] page_size = dst_kv_cache_layer_shape[1]
assert kv_cache_layer.is_contiguous() assert kv_cache_layer.is_contiguous()
...@@ -423,7 +423,7 @@ class DuSwiftConnectorDp(KVConnectorBase_V1): ...@@ -423,7 +423,7 @@ class DuSwiftConnectorDp(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx) Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise. if MLA is not used, and (num_pages, page_size, xxx) otherwise.
""" """
if isinstance(attn_metadata, MLACommonMetadata): if isinstance(attn_metadata, MLACommonMetadata) or kv_layer.ndim == 3:
num_pages, page_size = layer.shape[0], layer.shape[1] num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...] ...]
......
...@@ -20,6 +20,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri ...@@ -20,6 +20,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
from vllm.v1.worker.workspace import current_workspace_manager from vllm.v1.worker.workspace import current_workspace_manager
from lightop import op, gemmopt from lightop import op, gemmopt
from vllm.attention.utils.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
elif current_platform.is_xpu(): elif current_platform.is_xpu():
...@@ -27,10 +31,10 @@ elif current_platform.is_xpu(): ...@@ -27,10 +31,10 @@ elif current_platform.is_xpu():
logger = init_logger(__name__) logger = init_logger(__name__)
@maybe_transfer_kv_layer
def sparse_attn_indexer( def sparse_attn_indexer(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
k_cache_prefix: str, layer_name: str,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
...@@ -56,7 +60,7 @@ def sparse_attn_indexer( ...@@ -56,7 +60,7 @@ def sparse_attn_indexer(
) )
return sparse_attn_indexer_fake( return sparse_attn_indexer_fake(
hidden_states, hidden_states,
k_cache_prefix, layer_name,
kv_cache, kv_cache,
q_fp8, q_fp8,
k, k,
...@@ -69,7 +73,7 @@ def sparse_attn_indexer( ...@@ -69,7 +73,7 @@ def sparse_attn_indexer(
total_seq_lens, total_seq_lens,
topk_indices_buffer, topk_indices_buffer,
) )
attn_metadata = attn_metadata[k_cache_prefix] attn_metadata = attn_metadata[layer_name]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
...@@ -282,7 +286,7 @@ def sparse_attn_indexer( ...@@ -282,7 +286,7 @@ def sparse_attn_indexer(
def sparse_attn_indexer_fake( def sparse_attn_indexer_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
k_cache_prefix: str, layer_name: str,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
try: try:
__version__ = "0.15.1" __version__ = "0.15.1"
__version_tuple__ = (0, 15, 1) __version_tuple__ = (0, 15, 1)
__hcu_version__ = f'0.15.1+das.opt1.alpha.b40b83e.dtk2604' __hcu_version__ = f'0.15.1+das.opt1.alpha.9bfbaf9.dtk2604'
from vllm.version import __version__, __version_tuple__, __hcu_version__ from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e: except Exception as e:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment