Commit 6e9157c4 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

tbo增加ds int8支持,修改tb zerooverhead环境变量到envs.py中统一管理

See merge request dcutoolkit/deeplearing/vllm!120
parents cb563bb5 dbbb148b
...@@ -6,8 +6,8 @@ from contextlib import contextmanager ...@@ -6,8 +6,8 @@ from contextlib import contextmanager
from typing import Iterator, List, Optional, Union from typing import Iterator, List, Optional, Union
import cloudpickle import cloudpickle
import vllm.envs as envs
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.utils import is_zero_overhead
import zmq import zmq
from vllm import AsyncEngineArgs, SamplingParams from vllm import AsyncEngineArgs, SamplingParams
...@@ -81,7 +81,7 @@ class MQLLMEngine: ...@@ -81,7 +81,7 @@ class MQLLMEngine:
# the python object to be reused again. # the python object to be reused again.
kwargs['use_cached_outputs'] = True kwargs['use_cached_outputs'] = True
if is_zero_overhead(): if envs.VLLM_ZERO_OVERHEAD:
self.engine = ZeroOverheadEngine(*args, **kwargs) self.engine = ZeroOverheadEngine(*args, **kwargs)
else: else:
self.engine = LLMEngine(*args, **kwargs) self.engine = LLMEngine(*args, **kwargs)
......
...@@ -43,8 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, ...@@ -43,8 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of) is_list_of)
import vllm.envs as envs
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.utils import is_zero_overhead
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -248,7 +248,7 @@ class LLM: ...@@ -248,7 +248,7 @@ class LLM:
) )
# Create the Engine (autoselects V0 vs V1) # Create the Engine (autoselects V0 vs V1)
if is_zero_overhead(): if envs.VLLM_ZERO_OVERHEAD:
self.llm_engine = ZeroOverheadEngine.from_engine_args( self.llm_engine = ZeroOverheadEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
else: else:
......
...@@ -125,6 +125,8 @@ if TYPE_CHECKING: ...@@ -125,6 +125,8 @@ if TYPE_CHECKING:
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_ENABLE_TBO: bool = False
VLLM_ZERO_OVERHEAD: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -797,10 +799,21 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -797,10 +799,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENFORCE_EAGER_BS_THRESHOLD": "VLLM_ENFORCE_EAGER_BS_THRESHOLD":
lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")), lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")),
# If set, vLLM can avoid Device2Host copy during MLA prefill phase # Enable two batch overlap.
"VLLM_HAS_CONTEXT_DEFAULT": "VLLM_ENABLE_TBO":
lambda: bool(int(os.environ.get("VLLM_HAS_CONTEXT_DEFAULT", "0"))), lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
# 'has_context' is a variable in common.py, which is calculated
# from the metadata by default. However, this may introduce synchronization
# and affect performance, so it is directly assigned as False.
# If there are any problems during use, set this environment variable
# to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -30,10 +30,6 @@ forward_start_time: float = 0 ...@@ -30,10 +30,6 @@ forward_start_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list) batchsize_forward_time: defaultdict = defaultdict(list)
enable_tbo = os.environ.get('VLLM_ENABLE_TBO') == '1'
def is_enable_tbo():
return enable_tbo
@dataclass @dataclass
class DPMetadata: class DPMetadata:
cu_tokens_across_dp_cpu: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor
...@@ -55,7 +51,7 @@ _forward_context: Optional[ForwardContext] = None ...@@ -55,7 +51,7 @@ _forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext: def get_forward_context() -> ForwardContext:
if is_enable_tbo(): if envs.VLLM_ENABLE_TBO:
forward_context = get_tbo_forward_context() forward_context = get_tbo_forward_context()
"""Get the current forward context.""" """Get the current forward context."""
assert forward_context is not None, ( assert forward_context is not None, (
...@@ -125,7 +121,7 @@ def set_forward_context(attn_metadata: Any, ...@@ -125,7 +121,7 @@ def set_forward_context(attn_metadata: Any,
kv_connector = get_kv_transfer_group() kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1) assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.start_load_kv(_forward_context) kv_connector.start_load_kv(_forward_context)
if is_enable_tbo(): if envs.VLLM_ENABLE_TBO:
set_tbo_forward_context(_forward_context) set_tbo_forward_context(_forward_context)
try: try:
yield yield
...@@ -171,5 +167,5 @@ def set_forward_context(attn_metadata: Any, ...@@ -171,5 +167,5 @@ def set_forward_context(attn_metadata: Any,
kv_connector.wait_for_save() kv_connector.wait_for_save()
_forward_context = prev_context _forward_context = prev_context
if is_enable_tbo(): if envs.VLLM_ENABLE_TBO:
set_tbo_forward_context(_forward_context) set_tbo_forward_context(_forward_context)
...@@ -554,9 +554,8 @@ class FusedMoE(torch.nn.Module): ...@@ -554,9 +554,8 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
...@@ -940,7 +939,7 @@ class FusedMoE(torch.nn.Module): ...@@ -940,7 +939,7 @@ class FusedMoE(torch.nn.Module):
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.) # Default set to False. (May have to add shared expert outputs.)
if self.enable_tbo: if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states) final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else: else:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import itertools import itertools
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
import vllm.envs as envs
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
...@@ -1237,9 +1237,8 @@ class RowParallelLinear(LinearBase): ...@@ -1237,9 +1237,8 @@ class RowParallelLinear(LinearBase):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -1310,7 +1309,7 @@ class RowParallelLinear(LinearBase): ...@@ -1310,7 +1309,7 @@ class RowParallelLinear(LinearBase):
input_parallel, input_parallel,
bias=bias_) bias=bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
if self.enable_tbo: if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel) output = self.tbo_all_reduce(output_parallel)
else: else:
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
......
...@@ -21,7 +21,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, ...@@ -21,7 +21,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob, CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput) PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.zero_overhead.utils import is_zero_overhead
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling import flashinfer.sampling
...@@ -39,7 +38,7 @@ def get_sampler() -> torch.nn.Module: ...@@ -39,7 +38,7 @@ def get_sampler() -> torch.nn.Module:
# Lazy import: the v1 package isn't distributed # Lazy import: the v1 package isn't distributed
from vllm.v1.sample.sampler import Sampler as V1Sampler from vllm.v1.sample.sampler import Sampler as V1Sampler
return V1Sampler() return V1Sampler()
if is_zero_overhead(): if envs.VLLM_ZERO_OVERHEAD:
from vllm.zero_overhead.sampler import ZeroOverheadSampler from vllm.zero_overhead.sampler import ZeroOverheadSampler
return ZeroOverheadSampler() return ZeroOverheadSampler()
return Sampler() return Sampler()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple from typing import List, Optional, Sequence, Tuple
import vllm.envs as envs
import os import os
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -283,9 +283,8 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -283,9 +283,8 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_padded, self.num_embeddings_padded,
params_dtype=params_dtype, params_dtype=params_dtype,
weight_loader=self.weight_loader) weight_loader=self.weight_loader)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
@classmethod @classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
...@@ -437,7 +436,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -437,7 +436,7 @@ class VocabParallelEmbedding(torch.nn.Module):
if self.tp_size > 1: if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs. # Reduce across all the model parallel GPUs.
if self.enable_tbo: if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel) output = self.tbo_all_reduce(output_parallel)
else: else:
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
......
...@@ -155,9 +155,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -155,9 +155,8 @@ class DeepseekV2MoE(nn.Module):
reduce_results=False, reduce_results=False,
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
...@@ -191,7 +190,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -191,7 +190,7 @@ class DeepseekV2MoE(nn.Module):
# final_hidden_states = final_hidden_states + shared_output \ # final_hidden_states = final_hidden_states + shared_output \
# * (1. / self.routed_scaling_factor) # * (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
if self.enable_tbo: if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states) final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else: else:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
import os import os
import re import re
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import vllm.envs as envs
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -150,9 +150,8 @@ class DeepseekV3MoE(nn.Module): ...@@ -150,9 +150,8 @@ class DeepseekV3MoE(nn.Module):
quant_config=quant_config, quant_config=quant_config,
reduce_results=False, reduce_results=False,
) )
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
...@@ -167,7 +166,7 @@ class DeepseekV3MoE(nn.Module): ...@@ -167,7 +166,7 @@ class DeepseekV3MoE(nn.Module):
if shared_output is not None: if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1: if self.tp_size > 1:
if self.enable_tbo: if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states) final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else: else:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
......
...@@ -54,7 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase ...@@ -54,7 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.zero_overhead.utils import is_zero_overhead import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -207,7 +207,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -207,7 +207,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
# Load lm_head weight for eagle in init_device # Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle": if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True enable_lm_head_weight_load = True
if is_zero_overhead(): if envs.VLLM_ZERO_OVERHEAD:
assert False, ( assert False, (
"speculative decoding not support zero overhead scheduler yet" "speculative decoding not support zero overhead scheduler yet"
) )
...@@ -261,7 +261,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -261,7 +261,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the " "[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.") "target model is not running in eager mode.")
if is_zero_overhead(): if envs.VLLM_ZERO_OVERHEAD:
from vllm.zero_overhead.spec_decode.spec_decode_worker import ZeroOverheadSpecDecodeWorker from vllm.zero_overhead.spec_decode.spec_decode_worker import ZeroOverheadSpecDecodeWorker
return ZeroOverheadSpecDecodeWorker( return ZeroOverheadSpecDecodeWorker(
proposer_worker, proposer_worker,
......
import torch import torch
from vllm.attention.backends.flashmla import FlashMLAMetadata from vllm.attention.backends.flashmla import FlashMLAMetadata
from vllm.attention.backends.mla.common import MLACommonMetadata
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import async_tensor_h2d from vllm.utils import async_tensor_h2d
...@@ -12,6 +13,11 @@ def cumsum(lst): ...@@ -12,6 +13,11 @@ def cumsum(lst):
cum_lst.append(sum) cum_lst.append(sum)
return cum_lst return cum_lst
def is_supported_attention_metadata(atten_metadata):
    """Report whether ``atten_metadata`` is an instance of
    ``ROCmFlashAttentionMetadata``, ``FlashMLAMetadata`` or
    ``MLACommonMetadata``."""
    supported_types = (
        ROCmFlashAttentionMetadata,
        FlashMLAMetadata,
        MLACommonMetadata,
    )
    return isinstance(atten_metadata, supported_types)
def split_model_input(model_input, self_device, batch_size_left, batch_size_right): def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])] query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
...@@ -93,6 +99,66 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ ...@@ -93,6 +99,66 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1 selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1 selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1
if isinstance(model_input.attn_metadata, MLACommonMetadata):
attn_metadata_left = MLACommonMetadata(
num_prefills = num_prefills_left,
num_prefill_tokens = num_prefill_tokens_left,
num_decode_tokens = num_decode_tokens_left,
slot_mapping = split_slot_mapping[0],
multi_modal_placeholder_index_maps = model_input.attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
input_positions = split_input_positions[0],
seq_lens = seq_lens_left,
seq_lens_tensor = split_seq_lens_tensor[0],
max_prefill_seq_len = max_prefill_seq_len_left,
max_decode_seq_len = max_decode_seq_len_left,
context_lens_tensor = split_context_lens_tensor[0],
block_tables = split_block_tables[0],
max_query_len = max_query_len_left,
max_decode_query_len = max_decode_query_len_left,
query_start_loc = query_start_loc_left,
seq_start_loc = seq_start_loc_left,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
head_dim = model_input.attn_metadata.head_dim,
is_profile_run = model_input.attn_metadata.is_profile_run,
context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
)
attn_metadata_right = MLACommonMetadata(
num_prefills = num_prefills_right,
num_prefill_tokens = num_prefill_tokens_right,
num_decode_tokens = num_decode_tokens_right,
slot_mapping = split_slot_mapping[1],
multi_modal_placeholder_index_maps = model_input.attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
input_positions = split_input_positions[1],
seq_lens = seq_lens_right,
seq_lens_tensor = split_seq_lens_tensor[1],
max_prefill_seq_len = max_prefill_seq_len_right,
max_decode_seq_len = max_decode_seq_len_right,
context_lens_tensor = split_context_lens_tensor[1],
block_tables = split_block_tables[1],
max_query_len = max_query_len_right,
max_decode_query_len = max_decode_query_len_right,
query_start_loc = query_start_loc_right,
seq_start_loc = seq_start_loc_right,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
head_dim = model_input.attn_metadata.head_dim,
is_profile_run = model_input.attn_metadata.is_profile_run,
context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
)
if isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata): if isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata):
block_tables_list_left = model_input.attn_metadata.block_tables_list[0:batch_size_left] block_tables_list_left = model_input.attn_metadata.block_tables_list[0:batch_size_left]
block_tables_list_right = model_input.attn_metadata.block_tables_list[batch_size_left:] block_tables_list_right = model_input.attn_metadata.block_tables_list[batch_size_left:]
......
...@@ -3,18 +3,14 @@ import os ...@@ -3,18 +3,14 @@ import os
import queue import queue
import threading import threading
import torch import torch
from vllm.attention.backends.flashmla import FlashMLAMetadata
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import split_model_input from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
from vllm.utils import async_tensor_h2d
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm import envs
enable_tbo = os.environ.get('VLLM_ENABLE_TBO') == '1'
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1' enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'
...@@ -22,9 +18,6 @@ tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1' ...@@ -22,9 +18,6 @@ tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
logger = init_logger(__name__) logger = init_logger(__name__)
def is_enable_tbo():
return enable_tbo
tbo_step_stream = None tbo_step_stream = None
all_reduce_stream = None all_reduce_stream = None
...@@ -62,6 +55,7 @@ class TwoBatchOverlap(): ...@@ -62,6 +55,7 @@ class TwoBatchOverlap():
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,)) self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.left_thread.start() self.left_thread.start()
self.right_thread.start() self.right_thread.start()
logger.info('tbo:two batch overlap threads start')
def finish_thread(self): def finish_thread(self):
if self.left_thread != None: if self.left_thread != None:
...@@ -81,11 +75,9 @@ class TwoBatchOverlap(): ...@@ -81,11 +75,9 @@ class TwoBatchOverlap():
if queue == self.model_input_left_queue: if queue == self.model_input_left_queue:
self.left_tid = tid self.left_tid = tid
is_left_thread = True is_left_thread = True
logger.info('tbo:new thread %d', self.left_tid)
init_tbo_forward_context(True, self.left_tid) init_tbo_forward_context(True, self.left_tid)
else: else:
self.right_tid = tid self.right_tid = tid
logger.info('tbo:new thread %d', self.right_tid)
init_tbo_forward_context(False, self.right_tid) init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(tbo_step_stream): with torch.cuda.stream(tbo_step_stream):
while True: while True:
...@@ -177,7 +169,7 @@ class TwoBatchOverlap(): ...@@ -177,7 +169,7 @@ class TwoBatchOverlap():
tbo_obj = None tbo_obj = None
def init_two_batch_overlap(): def init_two_batch_overlap():
if enable_tbo: if envs.VLLM_ENABLE_TBO:
global tbo_obj global tbo_obj
if tbo_obj == None: if tbo_obj == None:
tbo_obj = TwoBatchOverlap() tbo_obj = TwoBatchOverlap()
...@@ -189,7 +181,7 @@ def finish_two_batch_overlap(): ...@@ -189,7 +181,7 @@ def finish_two_batch_overlap():
tbo_obj = None tbo_obj = None
def tbo_all_reduce(obj): def tbo_all_reduce(obj):
if enable_tbo and tbo_obj != None and tbo_obj.tbo_running: if envs.VLLM_ENABLE_TBO and tbo_obj != None and tbo_obj.tbo_running:
tid = threading.get_ident() tid = threading.get_ident()
if not tbo_one_stream: if not tbo_one_stream:
if tid == tbo_obj.left_tid: if tid == tbo_obj.left_tid:
...@@ -219,14 +211,14 @@ def tbo_model_executable( ...@@ -219,14 +211,14 @@ def tbo_model_executable(
seqlen_agnostic_kwargs, seqlen_agnostic_kwargs,
model_kwargs, model_kwargs,
): ):
init_two_batch_overlap() is_support = is_supported_attention_metadata(model_input.attn_metadata)
is_rocm_fa = isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata) if not is_support:
is_mla_fa = isinstance(model_input.attn_metadata, FlashMLAMetadata) logger.info("tbo:not surpport yet ", type(model_input.attn_metadata))
is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt
batch_size = len(model_input.attn_metadata.seq_lens) batch_size = len(model_input.attn_metadata.seq_lens)
if batch_size == 1 or \ if batch_size == 1 or \
(not model_input.is_prompt and not enable_tbo_decode) or \ (not model_input.is_prompt and not enable_tbo_decode) or \
not (is_rocm_fa or is_mla_fa) or \ not is_support or \
is_cuda_graph_decode: is_cuda_graph_decode:
with set_forward_context(model_input.attn_metadata, with set_forward_context(model_input.attn_metadata,
vllm_config, virtual_engine): vllm_config, virtual_engine):
...@@ -241,6 +233,7 @@ def tbo_model_executable( ...@@ -241,6 +233,7 @@ def tbo_model_executable(
) )
return hidden_or_intermediate_states return hidden_or_intermediate_states
profile.ProfRangePush('tbo_model_executable') profile.ProfRangePush('tbo_model_executable')
init_two_batch_overlap()
tbo_obj.tbo_running = True tbo_obj.tbo_running = True
tbo_obj.left_first = True tbo_obj.left_first = True
batch_size_left = int(batch_size / 2) batch_size_left = int(batch_size / 2)
......
...@@ -50,7 +50,7 @@ from vllm.prompt_adapter.worker_manager import ( ...@@ -50,7 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager) LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.two_batch_overlap.two_batch_overlap import is_enable_tbo, tbo_model_executable from vllm.two_batch_overlap.two_batch_overlap import tbo_model_executable
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
async_tensor_h2d, flatten_2d_lists, async_tensor_h2d, flatten_2d_lists,
is_pin_memory_available, supports_dynamo, is_pin_memory_available, supports_dynamo,
...@@ -61,7 +61,6 @@ from vllm.worker.model_runner_base import ( ...@@ -61,7 +61,6 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict) _init_sampling_metadata_from_tensor_dict)
from vllm.zero_overhead.utils import is_zero_overhead
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -1640,7 +1639,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1640,7 +1639,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
if is_zero_overhead(): if envs.VLLM_ZERO_OVERHEAD:
from vllm.zero_overhead.model_runner import ZeroOverheadModelInputForGpuBuilder from vllm.zero_overhead.model_runner import ZeroOverheadModelInputForGpuBuilder
_builder_cls = ZeroOverheadModelInputForGpuBuilder _builder_cls = ZeroOverheadModelInputForGpuBuilder
...@@ -1779,7 +1778,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1779,7 +1778,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_start.record() model_forward_start.record()
if not bypass_model_exec: if not bypass_model_exec:
if is_enable_tbo(): if envs.VLLM_ENABLE_TBO:
hidden_or_intermediate_states = tbo_model_executable( hidden_or_intermediate_states = tbo_model_executable(
model_input, model_input,
self.vllm_config, self.vllm_config,
......
...@@ -3,15 +3,12 @@ ...@@ -3,15 +3,12 @@
from enum import Enum from enum import Enum
import os import os
import torch import torch
import vllm.envs as envs
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1' zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'
def is_zero_overhead():
return zero_overhead
def is_zero_no_thread(): def is_zero_no_thread():
return zero_no_thread and zero_overhead return zero_no_thread and envs.VLLM_ZERO_OVERHEAD
class SpecStepKind(Enum): class SpecStepKind(Enum):
KIND_DEFAULT = 0 KIND_DEFAULT = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment