Commit 7462218e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1'

parents 6ccd3f47 1cec5e62
...@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Activation ops // Activation ops
// Activation function used in SwiGLU. // Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
...@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
// Activation function used in SwiGLU. (opt)
ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul_opt);
// Activation function used in GeGLU with `none` approximation. (opt)
ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul_opt);
// Activation function used in GeGLU with `tanh` approximation. (opt)
ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul_opt);
// Layernorm // Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def( ops.def(
...@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"); "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops.def(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);
// In-place fused Add and RMS Normalization. (opt)
ops.def(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);
// Rotary embedding // Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def( ops.def(
...@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()"); " Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Rotary embedding variant for TGI (takes separate cos and sin caches).
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()");
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras). // (supports multiple loras).
ops.def( ops.def(
...@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()"); " Tensor cos_sin_cache_offsets) -> ()");
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantization ops // Quantization ops
#ifndef USE_ROCM #ifndef USE_ROCM
// Quantized GEMM for AQLM. // Quantized GEMM for AQLM.
......
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# Sample prompts. if __name__ == '__main__':
prompts = [ # Sample prompts.
prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The future of AI is",
] ]
# Create a sampling params object. # Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m",trust_remote_code=True, dtype="float16", enforce_eager=True) llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, distributed_executor_backend="ray", dtype="float16",trust_remote_code=True, enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# Print the outputs. # Print the outputs.
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
from vllm.sampling_params import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
import asyncio
from vllm.utils import FlexibleArgumentParser
from transformers import AutoTokenizer
import logging
import argparse
import sys
vllm_logger = logging.getLogger("vllm")
vllm_logger.setLevel(logging.WARNING)
class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that accepts underscores in place of dashes in option names.

    Every token starting with ``--`` has the underscores in its option name
    rewritten to dashes before normal parsing. The ``=value`` part of a
    ``--name=value`` token, and every token that does not start with ``--``,
    is passed through untouched.
    """

    def parse_args(self, args=None, namespace=None):
        """Normalise long option names, then delegate to ``ArgumentParser``.

        Args:
            args: Argument strings to parse; ``sys.argv[1:]`` when ``None``.
            namespace: Optional namespace object to populate.

        Returns:
            Whatever ``argparse.ArgumentParser.parse_args`` returns.
        """
        raw_args = sys.argv[1:] if args is None else args

        def _normalise(token):
            # Only long options are rewritten; values and positionals are
            # returned unchanged.
            if not token.startswith('--'):
                return token
            name, sep, value = token[len('--'):].partition('=')
            # Underscores are replaced in the option name only, never in the
            # value that follows the first '='.
            return '--' + name.replace('_', '-') + sep + value

        return super().parse_args([_normalise(tok) for tok in raw_args],
                                  namespace)
# Command-line interface: every AsyncEngineArgs option plus an optional
# --template pointing at a chat-template file.
parser = FlexibleArgumentParser()
parser.add_argument('--template', type=str, help="Path to template")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

# Shape of the message list that is later passed to apply_chat_template:
# chat = [
#   {"role": "user", "content": "Hello, how are you?"},
#   {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
#   {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
tokenizer = AutoTokenizer.from_pretrained(args.model)

# Override the tokenizer's chat template only when a template file was given.
# The context manager closes the file on every path; the previous
# `finally: f.close()` raised NameError whenever open() itself failed
# (including when --template was omitted), because `f` was never bound.
if args.template is not None:
    try:
        with open(args.template, 'r') as f:
            tokenizer.chat_template = f.read()
    except (OSError, UnicodeError) as e:
        # Best effort: report the problem and leave chat_template unassigned.
        print('except:', e)

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)

# Display name: last "/"-separated component of the model argument, or the
# one before it when the argument ends with "/".
model_parts = args.model.split("/")
model_name = model_parts[-1] if model_parts[-1] != "" else model_parts[-2]
print(f"欢迎使用{model_name}模型，输入内容即可进行对话，stop 终止程序")
def build_prompt(history):
    """Render ``(query, response)`` pairs from *history* as one transcript string.

    Each pair contributes a user line followed by a line labelled with the
    module-level ``model_name``. An empty history yields ``""``.
    """
    turns = [
        f"\n\n用户：{query}" + f"\n\n{model_name}：{response}"
        for query, response in history
    ]
    return "".join(turns)
async def _stream_reply(prompt, request_id):
    """Stream one generation to stdout and return the final reply text.

    Args:
        prompt: Fully templated prompt string handed to the engine.
        request_id: Identifier for this request, passed to ``engine.generate``.

    Returns:
        The text of the last output yielded by the engine, or ``""`` if the
        engine yielded nothing.
    """
    printed = 0
    reply = ""
    results_generator = engine.generate(
        prompt,
        SamplingParams(temperature=0.0, max_tokens=100),
        request_id,
    )
    async for output in results_generator:
        reply = output.outputs[0].text
        # Print only the part that has not been shown yet; `printed` tracks
        # how many characters of the text were already written.
        print(reply[printed:], end="", flush=True)
        printed = len(reply)
    return reply


# Interactive loop: read a user turn, generate, print, remember both sides.
history = []
turn = 0
while True:
    query = input("\n用户：")
    if query.strip() == "stop":
        break
    history.append({"role": "user", "content": query})
    new_query = tokenizer.apply_chat_template(history, tokenize=False)
    # Each turn gets its own string request id instead of reusing the int 0.
    response = asyncio.run(_stream_reply(new_query, str(turn)))
    turn += 1
    history.append({"role": "assistant", "content": response})
    print()
{% if messages[0]['role'] == 'system' %}
{% set system_message = '<<SYS>>\n' + messages[0]['content'] | trim + '\n<</SYS>>\n\n' %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 %}
{% set content = system_message + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content | trim + ' ' + eos_token }}
{% endif %}
{% endfor %}
\ No newline at end of file
...@@ -2,6 +2,5 @@ ...@@ -2,6 +2,5 @@
-r requirements-common.txt -r requirements-common.txt
# Dependencies for AMD GPUs # Dependencies for AMD GPUs
ray == 2.9.1 ray >= 2.10.0
# ray >= 2.10.0
pytest-asyncio pytest-asyncio
...@@ -18,6 +18,9 @@ from typing import Optional, Union ...@@ -18,6 +18,9 @@ from typing import Optional, Union
import subprocess import subprocess
from pathlib import Path from pathlib import Path
# Build-time switch read from the environment: True only when
# ADD_GIT_VERSION parses to the integer 1 (unset is treated as "0").
add_git_version = int(os.environ.get('ADD_GIT_VERSION', '0')) == 1
def load_module_from_path(module_name, path): def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path) spec = importlib.util.spec_from_file_location(module_name, path)
...@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str: ...@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str:
raise RuntimeError("Unable to find version string.") raise RuntimeError("Unable to find version string.")
def get_abi():
    """Return ``"abi"`` plus the last token of gcc's ``_GLIBCXX_USE_CXX11_ABI`` line.

    The value is read by preprocessing ``<string>`` with gcc and filtering
    the macro dump. If running the probe raises, ``'abiUnknown'`` is
    returned instead.
    """
    probe = "echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI"
    try:
        completed = subprocess.run(probe, shell=True, capture_output=True, text=True)
        define_line = completed.stdout.strip()
        # The last space-separated token of the matched line is the macro value.
        return "abi" + define_line.split(" ")[-1]
    except Exception:
        return 'abiUnknown'
def get_sha(root: Union[str, Path]) -> str: def get_sha(root: Union[str, Path]) -> str:
try: try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=root).decode('ascii').strip() return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=root).decode('ascii').strip()
except Exception: except Exception:
return 'Unknown' return 'Unknown'
def get_version_add(sha: Optional[str] = None) -> str: def get_version_add(sha: Optional[str] = None) -> str:
vllm_root = os.path.dirname(os.path.abspath(__file__)) vllm_root = os.path.dirname(os.path.abspath(__file__))
add_version_path = os.path.join(os.path.join(vllm_root, "vllm"), "version.py") add_version_path = os.path.join(os.path.join(vllm_root, "vllm"), "version.py")
if add_git_version:
if sha != 'Unknown': if sha != 'Unknown':
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
version = 'das1.1.git' + sha[:7] version = 'das.opt1' + sha[:7]
else:
# abi version version = 'das.opt1'
version += "." + get_abi()
# dtk version # dtk version
if os.getenv("ROCM_PATH"): if os.getenv("ROCM_PATH"):
...@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version") rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version")
with open(rocm_version_path, 'r',encoding='utf-8') as file: with open(rocm_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines() lines = file.readlines()
rocm_version=lines[0][:-2].replace(".", "") rocm_version=lines[0].replace(".", "")
version += ".dtk" + rocm_version version += ".dtk" + rocm_version
# torch version
version += ".torch" + torch.__version__[:5]
with open(add_version_path, encoding="utf-8",mode="w") as file: with open(add_version_path, encoding="utf-8",mode="w") as file:
file.write("__version__='0.5.0.post1'\n") file.write("__version__='0.5.0.post1'\n")
file.write("__dcu_version__='0.5.0.post1+{}'\n".format(version)) file.write("__dcu_version__='0.5.0.post1+{}'\n".format(version))
......
...@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute( ...@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) # @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
def test_preemption( def test_preemption(
caplog_vllm, caplog_vllm,
...@@ -118,7 +119,8 @@ def test_preemption( ...@@ -118,7 +119,8 @@ def test_preemption(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) # @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4]) @pytest.mark.parametrize("beam_width", [4])
def test_swap( def test_swap(
...@@ -176,7 +178,8 @@ def test_swap( ...@@ -176,7 +178,8 @@ def test_swap(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) # @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4]) @pytest.mark.parametrize("beam_width", [4])
def test_swap_infeasible( def test_swap_infeasible(
...@@ -220,7 +223,8 @@ def test_swap_infeasible( ...@@ -220,7 +223,8 @@ def test_swap_infeasible(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) # @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
def test_preemption_infeasible( def test_preemption_infeasible(
vllm_runner, vllm_runner,
......
...@@ -359,6 +359,8 @@ def test_multi_query_kv_attention( ...@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, p=0.0,
scale=scale, scale=scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
) )
output = output.squeeze(0) output = output.squeeze(0)
......
...@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask ...@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import is_hip
NUM_HEADS = [64] NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64] NUM_QUERIES_PER_KV = [1, 8, 64]
...@@ -158,6 +159,7 @@ def test_contexted_kv_attention( ...@@ -158,6 +159,7 @@ def test_contexted_kv_attention(
end_time = time.time() end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
if not is_hip():
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
attn_op = xops.fmha.cutlass.FwOp() attn_op = xops.fmha.cutlass.FwOp()
...@@ -373,6 +375,8 @@ def test_contexted_kv_attention_alibi( ...@@ -373,6 +375,8 @@ def test_contexted_kv_attention_alibi(
torch.cuda.synchronize() torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
if not is_hip():
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function, # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
......
import contextlib import contextlib
import functools import functools
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
# lmslim provides `quant_ops`, which the AWQ/GPTQ wrappers in this module call.
# The import is optional, so a failure must not abort loading this module.
try:
    from lmslim import quant_ops
except Exception as e:
    # Bind the name so it always exists at module level, and report through
    # the module logger rather than printing to stdout at import time.
    quant_ops = None
    logger.info(
        "Please install lmslim if you want to infer gptq or awq model "
        "(import failed: %s).", e)
try: try:
import vllm._C import vllm._C
except ImportError as e: except ImportError as e:
...@@ -58,6 +62,18 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: ...@@ -58,6 +62,18 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x) torch.ops._C.gelu_tanh_and_mul(out, x)
def silu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
    """Run the compiled ``silu_and_mul_opt`` op on ``x``.

    The op's registered schema marks ``out`` as the mutated tensor; nothing
    is returned.
    """
    kernel = torch.ops._C.silu_and_mul_opt
    kernel(out, x)
def gelu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
    """Run the compiled ``gelu_and_mul_opt`` op on ``x``.

    The op's registered schema marks ``out`` as the mutated tensor; nothing
    is returned.
    """
    kernel = torch.ops._C.gelu_and_mul_opt
    kernel(out, x)
def gelu_tanh_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
    """Run the compiled ``gelu_tanh_and_mul_opt`` op on ``x``.

    The op's registered schema marks ``out`` as the mutated tensor; nothing
    is returned.
    """
    kernel = torch.ops._C.gelu_tanh_and_mul_opt
    kernel(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_fast(out, x) torch.ops._C.gelu_fast(out, x)
...@@ -125,6 +141,65 @@ def paged_attention_v2( ...@@ -125,6 +141,65 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
# paged attention ops (opt)
def paged_attention_v1_opt(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``paged_attention_v1_opt`` op.

    All arguments are forwarded positionally, in the order of this
    signature. The op's registered schema marks ``out`` as the mutated
    tensor; nothing is returned.
    """
    # Grouped only for readability; the op still receives four separate ints.
    blocksparse_args = (
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v1_opt(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        kv_scale,
        tp_rank,
        *blocksparse_args,
    )
def paged_attention_v2_opt(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled ``paged_attention_v2_opt`` op.

    All arguments are forwarded positionally, in the order of this
    signature. The op's registered schema marks ``out`` as the mutated
    tensor; nothing is returned.
    """
    # Grouped only for readability; the op still receives four separate ints.
    blocksparse_args = (
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    # Positional forwarding matters here: this wrapper names the second
    # argument `exp_sum`, while the op schema calls it `exp_sums`.
    torch.ops._C.paged_attention_v2_opt(
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        kv_scale,
        tp_rank,
        *blocksparse_args,
    )
# pos encoding ops # pos encoding ops
def rotary_embedding( def rotary_embedding(
positions: torch.Tensor, positions: torch.Tensor,
...@@ -158,9 +233,30 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, ...@@ -158,9 +233,30 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None: weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# layer norm ops (opt)
def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> None:
    """Run the compiled ``rms_norm_opt`` op.

    The op's registered schema marks ``out`` as the mutated tensor; nothing
    is returned.
    """
    kernel = torch.ops._C.rms_norm_opt
    kernel(out, input, weight, epsilon)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
    """Run the compiled ``fused_add_rms_norm_opt`` op.

    The op's registered schema marks both ``input`` and ``residual`` as
    mutated tensors; nothing is returned.
    """
    kernel = torch.ops._C.fused_add_rms_norm_opt
    kernel(input, residual, weight, epsilon)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
                   row: int, col: int) -> None:
    """Run the compiled ``trans_w16_gemm`` op.

    The op's registered schema marks ``dst`` as the mutated tensor; nothing
    is returned.
    NOTE(review): the exact transform applied to ``src`` is not visible
    here; confirm against the kernel.
    """
    kernel = torch.ops._C.trans_w16_gemm
    kernel(dst, src, row, col)
# quantization ops # quantization ops
# awq # awq
def GetAWQShareWorkspaceSize() -> int:
    """Return the result of ``quant_ops.GetAWQShareWorkspaceSize()``."""
    workspace_size = quant_ops.GetAWQShareWorkspaceSize()
    return workspace_size
def GetAWQShareWorkspace() -> torch.Tensor:
    """Return the result of ``quant_ops.GetAWQShareWorkspace()``."""
    workspace = quant_ops.GetAWQShareWorkspace()
    return workspace
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int, zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor: thy: int) -> torch.Tensor:
...@@ -168,23 +264,56 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, ...@@ -168,23 +264,56 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
thx, thy) thx, thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, # def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: # scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
             zeros_and_scales: torch.Tensor,
             m: int, n: int, k: int,
             group_size: int, padding_group: int, splikspace: torch.Tensor,
             splikspacesize: int) -> torch.Tensor:
    """Forward to ``quant_ops.awq_gemm`` and return its result.

    Arguments are passed positionally in the order of this signature.
    NOTE(review): ``splikspace`` / ``splikspacesize`` presumably are the
    shared workspace tensor and its size; confirm against lmslim.
    """
    gemm_args = (
        input,
        weight,
        zeros_and_scales,
        m,
        n,
        k,
        group_size,
        padding_group,
        splikspace,
        splikspacesize,
    )
    return quant_ops.awq_gemm(*gemm_args)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
               group_size: int):
    """Forward to ``quant_ops.convert_s4`` and return whatever it returns."""
    converted = quant_ops.convert_s4(qw, qz, s, group_size)
    return converted
def sz_permute(sz: torch.Tensor) -> torch.Tensor:
    """Forward ``sz`` to ``quant_ops.sz_permute`` and return its result."""
    permuted = quant_ops.sz_permute(sz)
    return permuted
def dequant_w4_gemm_colmajor(qweight: torch.Tensor,
                             zeros_and_scale: torch.Tensor,
                             k: int,
                             n: int,
                             group_size: int) -> torch.Tensor:
    """Forward to ``quant_ops.dequant_w4_gemm_colmajor`` and return its result.

    Arguments are passed positionally in the order of this signature.
    """
    dequantized = quant_ops.dequant_w4_gemm_colmajor(
        qweight, zeros_and_scale, k, n, group_size)
    return dequantized
# gptq # gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool, b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor: bit: int) -> torch.Tensor:
return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, return quant_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit) b_g_idx, use_exllama, bit)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None: bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) quant_ops.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm # squeezellm
......
...@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen >= 8000
self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 if self.use_flash_attn_auto:
triton_attention) from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
self.attn_func = triton_attention flash_attn_varlen_func)
self.attn_func_triton = flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func_ck = flash_attn_varlen_func
logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA")
else:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func # triton_attention
logger.debug("Using Triton FA in ROCmBackend") logger.debug("Using Triton FA in ROCmBackend")
else: else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn # if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either # either
...@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
out, _ = self.attn_func( if self.use_flash_attn_auto:
query, if prefill_meta.max_prefill_seq_len >= 8000:
key, out = self.attn_func_triton(
value, q=query,
None, k=key,
prefill_meta.seq_start_loc, v=value,
prefill_meta.seq_start_loc, cu_seqlens_q=prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len, cu_seqlens_k=prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len, max_seqlens_q=prefill_meta.max_prefill_seq_len,
True, max_seqlens_k=prefill_meta.max_prefill_seq_len,
self.scale, softmax_scale=self.scale,
causal=True,
) )
else:
out = self.attn_func_ck(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
# out, _ = self.attn_func(
# query,
# key,
# value,
# None,
# prefill_meta.seq_start_loc,
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
# )
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
elif self.use_naive_attn: elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround. # Interleave for MQA workaround.
......
This diff is collapsed.
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
import vllm.envs as envs
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
...@@ -122,6 +123,33 @@ class PagedAttention: ...@@ -122,6 +123,33 @@ class PagedAttention:
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v1( ops.paged_attention_v1(
output, output,
query, query,
...@@ -156,6 +184,38 @@ class PagedAttention: ...@@ -156,6 +184,38 @@ class PagedAttention:
device=output.device, device=output.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V2 SIZE:")
print(f"exp_sums.shape = {exp_sums.shape}, max_logits.shape = {max_logits.shape}, tmp_output.shape = {tmp_output.shape}")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v2( ops.paged_attention_v2(
output, output,
exp_sums, exp_sums,
......
...@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0": ...@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0":
sliding_window=None): sliding_window=None):
cap = torch.cuda.get_device_capability() cap = torch.cuda.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64 BLOCK = 32 if cap[0] >= 8 else 32
# shape constraints # shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv assert Lq == Lk and Lk == Lv
...@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0": ...@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0":
if sliding_window is None or sliding_window <= 0: if sliding_window is None or sliding_window <= 0:
sliding_window = 0 sliding_window = 0
num_warps = 8 if Lk <= 64 else 8 num_warps = 8 if Lk <= 64 else 4
if alibi_slopes is not None: if alibi_slopes is not None:
_fwd_kernel_alibi[grid]( _fwd_kernel_alibi[grid](
q, q,
......
This diff is collapsed.
...@@ -173,7 +173,7 @@ class ModelConfig: ...@@ -173,7 +173,7 @@ class ModelConfig:
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"] rocm_supported_quantization = ["gptq", "squeezellm","awq"]
if self.quantization is not None: if self.quantization is not None:
self.quantization = self.quantization.lower() self.quantization = self.quantization.lower()
...@@ -279,6 +279,12 @@ class ModelConfig: ...@@ -279,6 +279,12 @@ class ModelConfig:
return self.hf_text_config.hidden_size return self.hf_text_config.hidden_size
def get_head_size(self) -> int: def get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
if hasattr(self.hf_text_config, "head_dim"): if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models. # FIXME(woosuk): This may not be true for all models.
......
...@@ -232,6 +232,8 @@ class LLMEngine: ...@@ -232,6 +232,8 @@ class LLMEngine:
load_config=load_config, load_config=load_config,
) )
init_success = False
try:
if not self.model_config.embedding_mode: if not self.model_config.embedding_mode:
self._initialize_kv_caches() self._initialize_kv_caches()
...@@ -288,6 +290,13 @@ class LLMEngine: ...@@ -288,6 +290,13 @@ class LLMEngine:
max_model_len=self.model_config.max_model_len) max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config) self.stat_logger.info("cache_config", self.cache_config)
tokenizer_group = self.get_tokenizer_group()
def get_tokenizer_for_seq(sequence: Sequence) -> "PreTrainedTokenizer":
    """Return the tokenizer that should be used for ``sequence``.

    This is a local closure, not a method: it is handed to the output
    processor and to ``StopChecker`` as a bare callable, in place of the
    former bound method, so it is invoked with the sequence as its only
    argument. It therefore must not declare a ``self`` parameter --
    doing so would bind the sequence to ``self`` and raise ``TypeError``
    for the missing ``sequence`` argument on the first call.

    Args:
        sequence: The sequence whose ``lora_request`` selects the
            tokenizer.

    Returns:
        Whatever ``tokenizer_group.get_lora_tokenizer`` returns for the
        sequence's LoRA request.
    """
    # ``tokenizer_group`` is captured from the enclosing scope, so the
    # callback looks the tokenizer up without going through the engine
    # object.
    return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
# Create sequence output processor, e.g. for beam search or # Create sequence output processor, e.g. for beam search or
# speculative decoding. # speculative decoding.
self.output_processor = ( self.output_processor = (
...@@ -296,12 +305,18 @@ class LLMEngine: ...@@ -296,12 +305,18 @@ class LLMEngine:
self.detokenizer, self.detokenizer,
self.scheduler, self.scheduler,
self.seq_counter, self.seq_counter,
self.get_tokenizer_for_seq, get_tokenizer_for_seq,
stop_checker=StopChecker( stop_checker=StopChecker(
self.scheduler_config.max_model_len, self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq, get_tokenizer_for_seq,
), ),
)) ))
init_success = True
finally:
if not init_success:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self.model_executor.shutdown()
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
...@@ -393,10 +408,10 @@ class LLMEngine: ...@@ -393,10 +408,10 @@ class LLMEngine:
def get_tokenizer(self) -> "PreTrainedTokenizer": def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(None) return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self, # def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer": # sequence: Sequence) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer( # return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request) # sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict( init_kwargs = dict(
...@@ -785,7 +800,8 @@ class LLMEngine: ...@@ -785,7 +800,8 @@ class LLMEngine:
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, output) self.do_log_stats(scheduler_outputs, output)
if not request_outputs: # if not request_outputs:
if not self.has_unfinished_requests():
# Stop the execute model loop in parallel workers until there are # Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in # more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks # torch.distributed ops which may otherwise timeout, and unblocks
......
...@@ -9,6 +9,9 @@ if TYPE_CHECKING: ...@@ -9,6 +9,9 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
...@@ -27,7 +30,7 @@ if TYPE_CHECKING: ...@@ -27,7 +30,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/" VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
...@@ -131,7 +134,22 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -131,7 +134,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention # flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN": "VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
("true", "1")),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")), ("true", "1")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
...@@ -145,7 +163,7 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -145,7 +163,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# timeout for each iteration in the engine # timeout for each iteration in the engine
"VLLM_ENGINE_ITERATION_TIMEOUT_S": "VLLM_ENGINE_ITERATION_TIMEOUT_S":
lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "120")),
# API key for VLLM API server # API key for VLLM API server
"VLLM_API_KEY": "VLLM_API_KEY":
...@@ -214,15 +232,13 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -214,15 +232,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_WORKER_MULTIPROC_METHOD": "VLLM_WORKER_MULTIPROC_METHOD":
lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
"VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),
# Timeout for fetching images when serving multimodal models # Timeout for fetching images when serving multimodal models
# Default is 5 seconds # Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT": "VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH":
lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread): ...@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)""" """Handle results from all workers (in background thread)"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(daemon=True) super().__init__(daemon=False)
# super().__init__(daemon=True)
self.result_queue = mp.Queue() self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
...@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread): ...@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
def __init__(self, workers: List['ProcessWorkerWrapper'], def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler): result_handler: ResultHandler):
super().__init__(daemon=True) super().__init__(daemon=False)
# super().__init__(daemon=True)
self.workers = workers self.workers = workers
self.result_handler = result_handler self.result_handler = result_handler
self._close = False self._close = False
...@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread): ...@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
self._close = True self._close = True
# Kill / cleanup all workers # Kill / cleanup all workers
# for worker in self.workers:
# process = worker.process
# if process.sentinel in dead_sentinels:
# process.join(JOIN_TIMEOUT_S)
# if process.exitcode is not None and process.exitcode != 0:
# logger.error("Worker %s pid %s died, exit code: %s",
# process.name, process.pid, process.exitcode)
if not sys.is_finalizing():
# Kill / cleanup all workers
died_count = 0
for worker in self.workers: for worker in self.workers:
process = worker.process process = worker.process
if process.sentinel in dead_sentinels: if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S) process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0: if process.exitcode is not None and process.exitcode != 0:
died_count += 1
logger.error("Worker %s pid %s died, exit code: %s", logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode) process.name, process.pid,
process.exitcode)
if died_count < len(self.workers):
logger.info(
"Killing remaining local vLLM worker processes")
# Cleanup any remaining workers # Cleanup any remaining workers
logger.info("Killing local vLLM worker processes") # logger.info("Killing local vLLM worker processes")
for worker in self.workers: for worker in self.workers:
worker.kill_worker() worker.kill_worker()
# Must be done after worker task queues are all closed # Must be done after worker task queues are all closed
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment