Commit 109c414a authored by zhuwenwen's avatar zhuwenwen
Browse files

fix the performance issue of tbo pd separation

parent e37d6cc3
......@@ -7,6 +7,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_maybe_save_kv_layer_to_connector
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
......@@ -480,7 +481,10 @@ def unified_attention(
output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
......@@ -528,7 +532,10 @@ def unified_attention_with_output(
output_scale=output_scale,
output_block_scale=output_block_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_attention_with_output_fake(
......
......@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
......@@ -262,6 +263,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
torch.Tensor: A tensor containing the extracted KV slices.
Returns None if the layout is unsupported.
"""
if envs.VLLM_ENABLE_TBO:
slot_mapping = slot_mapping.pin_memory().to(device=layer.device, non_blocking=True)
if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer
return layer[block_ids, ...]
......
......@@ -162,6 +162,14 @@ def init_two_batch_overlap():
tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread()
def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
from vllm.attention.layer import maybe_save_kv_layer_to_connector
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
tid = threading.get_ident()
if tid == tbo_obj_v1.left_tid:
return
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
tid = threading.get_ident()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment