Commit c0707728 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev_tbo' into 'v0.9.2-dev'

fix the performance issue of tbo pd separation

See merge request dcutoolkit/deeplearing/vllm!208
parents 4a80b456 f8985b96
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_maybe_save_kv_layer_to_connector
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
...@@ -412,7 +413,10 @@ def unified_attention( ...@@ -412,7 +413,10 @@ def unified_attention(
output = self.impl.forward(self, query, key, value, kv_cache, output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata) attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache) if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output return output
...@@ -457,8 +461,10 @@ def unified_attention_with_output( ...@@ -457,8 +461,10 @@ def unified_attention_with_output(
attn_metadata, attn_metadata,
output=output, output=output,
output_scale=output_scale) output_scale=output_scale)
if envs.VLLM_ENABLE_TBO:
maybe_save_kv_layer_to_connector(layer_name, kv_cache) tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_attention_with_output_fake( def unified_attention_with_output_fake(
......
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re import regex as re
import torch import torch
from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
...@@ -267,6 +268,8 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -267,6 +268,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx) Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise. if MLA is not used, and (num_pages, page_size, xxx) otherwise.
""" """
if envs.VLLM_ENABLE_TBO:
slot_mapping = slot_mapping.pin_memory().to(device=layer.device, non_blocking=True)
if isinstance(attn_metadata, MLACommonMetadata): if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1] num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, return layer.reshape(num_pages * page_size, -1)[slot_mapping,
......
...@@ -162,6 +162,14 @@ def init_two_batch_overlap(): ...@@ -162,6 +162,14 @@ def init_two_batch_overlap():
tbo_obj_v1 = TwoBatchOverlap() tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread() tbo_obj_v1.init_tbo_thread()
def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
from vllm.attention.layer import maybe_save_kv_layer_to_connector
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
tid = threading.get_ident()
if tid == tbo_obj_v1.left_tid:
return
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def tbo_all_reduce_v1(obj): def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running: if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
tid = threading.get_ident() tid = threading.get_ident()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment