"vllm/vscode:/vscode.git/clone" did not exist on "22de45235c6dd14e901e089971635ec655d5fbe0"
Commit 109c414a authored by zhuwenwen's avatar zhuwenwen
Browse files

fix the performance issue of tbo pd separation

parent e37d6cc3
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_maybe_save_kv_layer_to_connector
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -480,6 +481,9 @@ def unified_attention( ...@@ -480,6 +481,9 @@ def unified_attention(
output = self.impl.forward(self, query, key, value, kv_cache, output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata) attn_metadata)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output return output
...@@ -528,6 +532,9 @@ def unified_attention_with_output( ...@@ -528,6 +532,9 @@ def unified_attention_with_output(
output_scale=output_scale, output_scale=output_scale,
output_block_scale=output_block_scale) output_block_scale=output_block_scale)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
......
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re import regex as re
import torch import torch
from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
...@@ -262,6 +263,8 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -262,6 +263,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
torch.Tensor: A tensor containing the extracted KV slices. torch.Tensor: A tensor containing the extracted KV slices.
Returns None if the layout is unsupported. Returns None if the layout is unsupported.
""" """
if envs.VLLM_ENABLE_TBO:
slot_mapping = slot_mapping.pin_memory().to(device=layer.device, non_blocking=True)
if (isinstance(attn_metadata, MLACommonMetadata) if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer or layer.shape[1] == 2): # MLA or FlashInfer
return layer[block_ids, ...] return layer[block_ids, ...]
......
...@@ -162,6 +162,14 @@ def init_two_batch_overlap(): ...@@ -162,6 +162,14 @@ def init_two_batch_overlap():
tbo_obj_v1 = TwoBatchOverlap() tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread() tbo_obj_v1.init_tbo_thread()
def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
    """Save a layer's KV cache to the connector, except on the TBO left thread.

    When two-batch overlap (TBO) is enabled, the global ``tbo_obj_v1`` exists
    and reports ``tbo_running``, a call made from the thread whose id equals
    ``tbo_obj_v1.left_tid`` returns without saving. In every other case the
    call is forwarded unchanged to ``maybe_save_kv_layer_to_connector``.

    Args:
        layer_name: Passed through to ``maybe_save_kv_layer_to_connector``.
        kv_cache: Passed through to ``maybe_save_kv_layer_to_connector``.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably this avoids a circular import, because
    # vllm.attention.layer imports this module -- confirm.
    from vllm.attention.layer import maybe_save_kv_layer_to_connector

    # `is not None` is an identity check and cannot be affected by a custom
    # __eq__/__ne__ on the TBO object.
    if (envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None
            and tbo_obj_v1.tbo_running):
        if threading.get_ident() == tbo_obj_v1.left_tid:
            # The left thread skips the save; other threads fall through.
            return
    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def tbo_all_reduce_v1(obj): def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running: if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
tid = threading.get_ident() tid = threading.get_ident()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment