Commit 6e9157c4 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

tbo增加ds int8支持,修改tb zerooverhead环境变量到envs.py中统一管理

See merge request dcutoolkit/deeplearing/vllm!120
parents cb563bb5 dbbb148b
......@@ -6,8 +6,8 @@ from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
import cloudpickle
import vllm.envs as envs
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.utils import is_zero_overhead
import zmq
from vllm import AsyncEngineArgs, SamplingParams
......@@ -81,7 +81,7 @@ class MQLLMEngine:
# the python object to be reused again.
kwargs['use_cached_outputs'] = True
if is_zero_overhead():
if envs.VLLM_ZERO_OVERHEAD:
self.engine = ZeroOverheadEngine(*args, **kwargs)
else:
self.engine = LLMEngine(*args, **kwargs)
......
......@@ -43,8 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
import vllm.envs as envs
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.utils import is_zero_overhead
logger = init_logger(__name__)
......@@ -248,7 +248,7 @@ class LLM:
)
# Create the Engine (autoselects V0 vs V1)
if is_zero_overhead():
if envs.VLLM_ZERO_OVERHEAD:
self.llm_engine = ZeroOverheadEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
else:
......
......@@ -125,6 +125,8 @@ if TYPE_CHECKING:
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_ENABLE_TBO: bool = False
VLLM_ZERO_OVERHEAD: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -796,11 +798,22 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_ENFORCE_EAGER_BS_THRESHOLD":
lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")),
# If set, vLLM can avoid Device2Host copy during MLA prefill phase
# Enable two batch overlap.
"VLLM_ENABLE_TBO":
lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
# 'has_context' is a variable in common.py that is computed from the
# metadata by default. However, computing it may introduce synchronization
# and hurt performance, so it is directly assigned False instead.
# If this causes any problems, set this environment variable
# to restore the default behavior.
"VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.environ.get("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
}
# end-env-vars-definition
......
......@@ -30,10 +30,6 @@ forward_start_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list)
enable_tbo = os.environ.get('VLLM_ENABLE_TBO') == '1'
def is_enable_tbo():
return enable_tbo
@dataclass
class DPMetadata:
cu_tokens_across_dp_cpu: torch.Tensor
......@@ -55,7 +51,7 @@ _forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext:
if is_enable_tbo():
if envs.VLLM_ENABLE_TBO:
forward_context = get_tbo_forward_context()
"""Get the current forward context."""
assert forward_context is not None, (
......@@ -125,7 +121,7 @@ def set_forward_context(attn_metadata: Any,
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.start_load_kv(_forward_context)
if is_enable_tbo():
if envs.VLLM_ENABLE_TBO:
set_tbo_forward_context(_forward_context)
try:
yield
......@@ -171,5 +167,5 @@ def set_forward_context(attn_metadata: Any,
kv_connector.wait_for_save()
_forward_context = prev_context
if is_enable_tbo():
if envs.VLLM_ENABLE_TBO:
set_tbo_forward_context(_forward_context)
......@@ -554,9 +554,8 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
......@@ -940,7 +939,7 @@ class FusedMoE(torch.nn.Module):
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
if self.enable_tbo:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = tensor_model_parallel_all_reduce(
......
......@@ -3,7 +3,7 @@
import itertools
from abc import abstractmethod
from typing import Any, Literal, Optional, Union
import vllm.envs as envs
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter, UninitializedParameter
......@@ -1237,9 +1237,8 @@ class RowParallelLinear(LinearBase):
})
else:
self.register_parameter("bias", None)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
......@@ -1310,7 +1309,7 @@ class RowParallelLinear(LinearBase):
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
if self.enable_tbo:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
......
......@@ -21,7 +21,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.zero_overhead.utils import is_zero_overhead
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
......@@ -39,7 +38,7 @@ def get_sampler() -> torch.nn.Module:
# Lazy import: the v1 package isn't distributed
from vllm.v1.sample.sampler import Sampler as V1Sampler
return V1Sampler()
if is_zero_overhead():
if envs.VLLM_ZERO_OVERHEAD:
from vllm.zero_overhead.sampler import ZeroOverheadSampler
return ZeroOverheadSampler()
return Sampler()
......
......@@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
import vllm.envs as envs
import os
import torch
import torch.nn.functional as F
......@@ -283,9 +283,8 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
......@@ -437,7 +436,7 @@ class VocabParallelEmbedding(torch.nn.Module):
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
if self.enable_tbo:
if envs.VLLM_ENABLE_TBO:
output = self.tbo_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel)
......
......@@ -155,9 +155,8 @@ class DeepseekV2MoE(nn.Module):
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
......@@ -191,7 +190,7 @@ class DeepseekV2MoE(nn.Module):
# final_hidden_states = final_hidden_states + shared_output \
# * (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if self.enable_tbo:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = tensor_model_parallel_all_reduce(
......
......@@ -23,7 +23,7 @@
import os
import re
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import vllm.envs as envs
import torch
from torch import nn
from transformers import PretrainedConfig
......@@ -150,9 +150,8 @@ class DeepseekV3MoE(nn.Module):
quant_config=quant_config,
reduce_results=False,
)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
......@@ -167,7 +166,7 @@ class DeepseekV3MoE(nn.Module):
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
if self.enable_tbo:
if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states)
else:
final_hidden_states = tensor_model_parallel_all_reduce(
......
......@@ -54,7 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.zero_overhead.utils import is_zero_overhead
import vllm.envs as envs
logger = init_logger(__name__)
......@@ -207,7 +207,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
# Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True
if is_zero_overhead():
if envs.VLLM_ZERO_OVERHEAD:
assert False, (
"speculative decoding not support zero overhead scheduler yet"
)
......@@ -261,7 +261,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.")
if is_zero_overhead():
if envs.VLLM_ZERO_OVERHEAD:
from vllm.zero_overhead.spec_decode.spec_decode_worker import ZeroOverheadSpecDecodeWorker
return ZeroOverheadSpecDecodeWorker(
proposer_worker,
......
import torch
from vllm.attention.backends.flashmla import FlashMLAMetadata
from vllm.attention.backends.mla.common import MLACommonMetadata
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import async_tensor_h2d
......@@ -12,6 +13,11 @@ def cumsum(lst):
cum_lst.append(sum)
return cum_lst
def is_supported_attention_metadata(atten_metadata):
    """Report whether ``atten_metadata`` is one of the recognised metadata types.

    Args:
        atten_metadata: The attention metadata object to classify.

    Returns:
        True when the object is an instance of ROCmFlashAttentionMetadata,
        FlashMLAMetadata or MLACommonMetadata; False otherwise.
    """
    # The classes are checked in the same order as the original `or` chain.
    supported_types = (
        ROCmFlashAttentionMetadata,
        FlashMLAMetadata,
        MLACommonMetadata,
    )
    return isinstance(atten_metadata, supported_types)
def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
......@@ -93,6 +99,66 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1
if isinstance(model_input.attn_metadata, MLACommonMetadata):
attn_metadata_left = MLACommonMetadata(
num_prefills = num_prefills_left,
num_prefill_tokens = num_prefill_tokens_left,
num_decode_tokens = num_decode_tokens_left,
slot_mapping = split_slot_mapping[0],
multi_modal_placeholder_index_maps = model_input.attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
input_positions = split_input_positions[0],
seq_lens = seq_lens_left,
seq_lens_tensor = split_seq_lens_tensor[0],
max_prefill_seq_len = max_prefill_seq_len_left,
max_decode_seq_len = max_decode_seq_len_left,
context_lens_tensor = split_context_lens_tensor[0],
block_tables = split_block_tables[0],
max_query_len = max_query_len_left,
max_decode_query_len = max_decode_query_len_left,
query_start_loc = query_start_loc_left,
seq_start_loc = seq_start_loc_left,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
head_dim = model_input.attn_metadata.head_dim,
is_profile_run = model_input.attn_metadata.is_profile_run,
context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
)
attn_metadata_right = MLACommonMetadata(
num_prefills = num_prefills_right,
num_prefill_tokens = num_prefill_tokens_right,
num_decode_tokens = num_decode_tokens_right,
slot_mapping = split_slot_mapping[1],
multi_modal_placeholder_index_maps = model_input.attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
input_positions = split_input_positions[1],
seq_lens = seq_lens_right,
seq_lens_tensor = split_seq_lens_tensor[1],
max_prefill_seq_len = max_prefill_seq_len_right,
max_decode_seq_len = max_decode_seq_len_right,
context_lens_tensor = split_context_lens_tensor[1],
block_tables = split_block_tables[1],
max_query_len = max_query_len_right,
max_decode_query_len = max_decode_query_len_right,
query_start_loc = query_start_loc_right,
seq_start_loc = seq_start_loc_right,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
head_dim = model_input.attn_metadata.head_dim,
is_profile_run = model_input.attn_metadata.is_profile_run,
context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
)
if isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata):
block_tables_list_left = model_input.attn_metadata.block_tables_list[0:batch_size_left]
block_tables_list_right = model_input.attn_metadata.block_tables_list[batch_size_left:]
......
......@@ -3,18 +3,14 @@ import os
import queue
import threading
import torch
from vllm.attention.backends.flashmla import FlashMLAMetadata
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import split_model_input
from vllm.utils import async_tensor_h2d
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
from vllm.logger import init_logger
from vllm.profiler.prof import profile
enable_tbo = os.environ.get('VLLM_ENABLE_TBO') == '1'
from vllm import envs
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'
......@@ -22,9 +18,6 @@ tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
logger = init_logger(__name__)
def is_enable_tbo():
return enable_tbo
tbo_step_stream = None
all_reduce_stream = None
......@@ -62,6 +55,7 @@ class TwoBatchOverlap():
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.left_thread.start()
self.right_thread.start()
logger.info('tbo:two batch overlap threads start')
def finish_thread(self):
if self.left_thread != None:
......@@ -81,11 +75,9 @@ class TwoBatchOverlap():
if queue == self.model_input_left_queue:
self.left_tid = tid
is_left_thread = True
logger.info('tbo:new thread %d', self.left_tid)
init_tbo_forward_context(True, self.left_tid)
else:
self.right_tid = tid
logger.info('tbo:new thread %d', self.right_tid)
init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(tbo_step_stream):
while True:
......@@ -177,7 +169,7 @@ class TwoBatchOverlap():
tbo_obj = None
def init_two_batch_overlap():
if enable_tbo:
if envs.VLLM_ENABLE_TBO:
global tbo_obj
if tbo_obj == None:
tbo_obj = TwoBatchOverlap()
......@@ -189,7 +181,7 @@ def finish_two_batch_overlap():
tbo_obj = None
def tbo_all_reduce(obj):
if enable_tbo and tbo_obj != None and tbo_obj.tbo_running:
if envs.VLLM_ENABLE_TBO and tbo_obj != None and tbo_obj.tbo_running:
tid = threading.get_ident()
if not tbo_one_stream:
if tid == tbo_obj.left_tid:
......@@ -219,14 +211,14 @@ def tbo_model_executable(
seqlen_agnostic_kwargs,
model_kwargs,
):
init_two_batch_overlap()
is_rocm_fa = isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata)
is_mla_fa = isinstance(model_input.attn_metadata, FlashMLAMetadata)
is_support = is_supported_attention_metadata(model_input.attn_metadata)
if not is_support:
    # The metadata type is passed as a lazy %-style argument, so the message
    # needs a matching placeholder; without one the logging module reports a
    # formatting error and the type is never shown.
    logger.info("tbo:not surpport yet %s", type(model_input.attn_metadata))
is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt
batch_size = len(model_input.attn_metadata.seq_lens)
if batch_size == 1 or \
(not model_input.is_prompt and not enable_tbo_decode) or \
not (is_rocm_fa or is_mla_fa) or \
not is_support or \
is_cuda_graph_decode:
with set_forward_context(model_input.attn_metadata,
vllm_config, virtual_engine):
......@@ -241,6 +233,7 @@ def tbo_model_executable(
)
return hidden_or_intermediate_states
profile.ProfRangePush('tbo_model_executable')
init_two_batch_overlap()
tbo_obj.tbo_running = True
tbo_obj.left_first = True
batch_size_left = int(batch_size / 2)
......
......@@ -50,7 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.two_batch_overlap.two_batch_overlap import is_enable_tbo, tbo_model_executable
from vllm.two_batch_overlap.two_batch_overlap import tbo_model_executable
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
async_tensor_h2d, flatten_2d_lists,
is_pin_memory_available, supports_dynamo,
......@@ -61,7 +61,6 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm.zero_overhead.utils import is_zero_overhead
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
......@@ -1640,7 +1639,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
if is_zero_overhead():
if envs.VLLM_ZERO_OVERHEAD:
from vllm.zero_overhead.model_runner import ZeroOverheadModelInputForGpuBuilder
_builder_cls = ZeroOverheadModelInputForGpuBuilder
......@@ -1779,7 +1778,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_start.record()
if not bypass_model_exec:
if is_enable_tbo():
if envs.VLLM_ENABLE_TBO:
hidden_or_intermediate_states = tbo_model_executable(
model_input,
self.vllm_config,
......
......@@ -3,15 +3,12 @@
from enum import Enum
import os
import torch
import vllm.envs as envs
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'
def is_zero_overhead():
return zero_overhead
def is_zero_no_thread():
return zero_no_thread and zero_overhead
return zero_no_thread and envs.VLLM_ZERO_OVERHEAD
class SpecStepKind(Enum):
KIND_DEFAULT = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment