Commit 7462218e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1'

parents 6ccd3f47 1cec5e62
...@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Activation ops // Activation ops
// Activation function used in SwiGLU. // Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
...@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
// Activation function used in SwiGLU. (opt)
ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul_opt);
// Activation function used in GeGLU with `none` approximation. (opt)
ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul_opt);
// Activation function used in GeGLU with `tanh` approximation. (opt)
ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul_opt);
// Layernorm // Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def( ops.def(
...@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"); "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops.def(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);
// In-place fused Add and RMS Normalization. (opt)
ops.def(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);
// Rotary embedding // Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def( ops.def(
...@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()"); " Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Rotary embedding TGI for TGI
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()");
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras). // (supports multiple loras).
ops.def( ops.def(
...@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()"); " Tensor cos_sin_cache_offsets) -> ()");
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantization ops // Quantization ops
#ifndef USE_ROCM #ifndef USE_ROCM
// Quantized GEMM for AQLM. // Quantized GEMM for AQLM.
......
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# Sample prompts. if __name__ == '__main__':
prompts = [ # Sample prompts.
"Hello, my name is", prompts = [
"The president of the United States is", "Hello, my name is",
"The capital of France is", "The president of the United States is",
"The future of AI is", "The capital of France is",
] "The future of AI is",
# Create a sampling params object. ]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m",trust_remote_code=True, dtype="float16", enforce_eager=True) llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, distributed_executor_backend="ray", dtype="float16",trust_remote_code=True, enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# Print the outputs. # Print the outputs.
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
from vllm.sampling_params import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
import asyncio
from vllm.utils import FlexibleArgumentParser
from transformers import AutoTokenizer
import logging
import argparse
import sys
vllm_logger = logging.getLogger("vllm")
vllm_logger.setLevel(logging.WARNING)
class FlexibleArgumentParser(argparse.ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
def parse_args(self, args=None, namespace=None):
if args is None:
args = sys.argv[1:]
# Convert underscores to dashes and vice versa in argument names
processed_args = []
for arg in args:
if arg.startswith('--'):
if '=' in arg:
key, value = arg.split('=', 1)
key = '--' + key[len('--'):].replace('_', '-')
processed_args.append(f'{key}={value}')
else:
processed_args.append('--' +
arg[len('--'):].replace('_', '-'))
else:
processed_args.append(arg)
return super().parse_args(processed_args, namespace)
parser = FlexibleArgumentParser()
parser.add_argument('--template', type=str, help="Path to template")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
# chat = [
# {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
tokenizer = AutoTokenizer.from_pretrained(args.model)
try:
f = open(args.template,'r')
tokenizer.chat_template = f.read()
except Exception as e:
print('except:',e)
finally:
f.close()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
model_name = args.model.split("/")[-1] if args.model.split("/")[-1] !="" else args.model.split("/")[-2]
print(f"欢迎使用{model_name}模型,输入内容即可进行对话,stop 终止程序")
def build_prompt(history):
prompt = ""
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\n{model_name}:{response}"
return prompt
history = []
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
history.append({"role": "user", "content": query})
new_query = tokenizer.apply_chat_template(history, tokenize=False)
example_input = {
"prompt": new_query,
"stream": False,
"temperature": 0.0,
"request_id": 0,
}
results_generator = engine.generate(
example_input["prompt"],
SamplingParams(temperature=example_input["temperature"], max_tokens=100),
example_input["request_id"]
)
start = 0
end = 0
response = ""
async def process_results():
async for output in results_generator:
global end
global start
global response
print(output.outputs[0].text[start:], end="", flush=True)
length = len(output.outputs[0].text)
start = length
response = output.outputs[0].text
asyncio.run(process_results())
history.append({"role": "assistant", "content": response})
print()
{% if messages[0]['role'] == 'system' %}
{% set system_message = '<<SYS>>\n' + messages[0]['content'] | trim + '\n<</SYS>>\n\n' %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 %}
{% set content = system_message + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content | trim + ' ' + eos_token }}
{% endif %}
{% endfor %}
\ No newline at end of file
...@@ -2,6 +2,5 @@ ...@@ -2,6 +2,5 @@
-r requirements-common.txt -r requirements-common.txt
# Dependencies for AMD GPUs # Dependencies for AMD GPUs
ray == 2.9.1 ray >= 2.10.0
# ray >= 2.10.0
pytest-asyncio pytest-asyncio
...@@ -18,6 +18,9 @@ from typing import Optional, Union ...@@ -18,6 +18,9 @@ from typing import Optional, Union
import subprocess import subprocess
from pathlib import Path from pathlib import Path
add_git_version = False
if int(os.environ.get('ADD_GIT_VERSION', '0')) == 1:
add_git_version = True
def load_module_from_path(module_name, path): def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path) spec = importlib.util.spec_from_file_location(module_name, path)
...@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str: ...@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str:
raise RuntimeError("Unable to find version string.") raise RuntimeError("Unable to find version string.")
def get_abi():
try:
command = "echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stdout.strip()
abi = "abi" + output.split(" ")[-1]
return abi
except Exception:
return 'abiUnknown'
def get_sha(root: Union[str, Path]) -> str: def get_sha(root: Union[str, Path]) -> str:
try: try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=root).decode('ascii').strip() return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=root).decode('ascii').strip()
except Exception: except Exception:
return 'Unknown' return 'Unknown'
def get_version_add(sha: Optional[str] = None) -> str: def get_version_add(sha: Optional[str] = None) -> str:
vllm_root = os.path.dirname(os.path.abspath(__file__)) vllm_root = os.path.dirname(os.path.abspath(__file__))
add_version_path = os.path.join(os.path.join(vllm_root, "vllm"), "version.py") add_version_path = os.path.join(os.path.join(vllm_root, "vllm"), "version.py")
if sha != 'Unknown': if add_git_version:
if sha is None: if sha != 'Unknown':
sha = get_sha(vllm_root) if sha is None:
version = 'das1.1.git' + sha[:7] sha = get_sha(vllm_root)
version = 'das.opt1' + sha[:7]
# abi version else:
version += "." + get_abi() version = 'das.opt1'
# dtk version # dtk version
if os.getenv("ROCM_PATH"): if os.getenv("ROCM_PATH"):
...@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version") rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version")
with open(rocm_version_path, 'r',encoding='utf-8') as file: with open(rocm_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines() lines = file.readlines()
rocm_version=lines[0][:-2].replace(".", "") rocm_version=lines[0].replace(".", "")
version += ".dtk" + rocm_version version += ".dtk" + rocm_version
# torch version
version += ".torch" + torch.__version__[:5]
with open(add_version_path, encoding="utf-8",mode="w") as file: with open(add_version_path, encoding="utf-8",mode="w") as file:
file.write("__version__='0.5.0.post1'\n") file.write("__version__='0.5.0.post1'\n")
file.write("__dcu_version__='0.5.0.post1+{}'\n".format(version)) file.write("__dcu_version__='0.5.0.post1+{}'\n".format(version))
......
...@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute( ...@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) # @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
def test_preemption( def test_preemption(
caplog_vllm, caplog_vllm,
...@@ -118,7 +119,8 @@ def test_preemption( ...@@ -118,7 +119,8 @@ def test_preemption(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) # @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4]) @pytest.mark.parametrize("beam_width", [4])
def test_swap( def test_swap(
...@@ -176,7 +178,8 @@ def test_swap( ...@@ -176,7 +178,8 @@ def test_swap(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) # @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4]) @pytest.mark.parametrize("beam_width", [4])
def test_swap_infeasible( def test_swap_infeasible(
...@@ -220,7 +223,8 @@ def test_swap_infeasible( ...@@ -220,7 +223,8 @@ def test_swap_infeasible(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) # @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("max_tokens", [96])
def test_preemption_infeasible( def test_preemption_infeasible(
vllm_runner, vllm_runner,
......
...@@ -359,6 +359,8 @@ def test_multi_query_kv_attention( ...@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, p=0.0,
scale=scale, scale=scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
) )
output = output.squeeze(0) output = output.squeeze(0)
......
...@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask ...@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import is_hip
NUM_HEADS = [64] NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64] NUM_QUERIES_PER_KV = [1, 8, 64]
...@@ -158,57 +159,58 @@ def test_contexted_kv_attention( ...@@ -158,57 +159,58 @@ def test_contexted_kv_attention(
end_time = time.time() end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
scale = float(1.0 / (head_size**0.5)) if not is_hip():
scale = float(1.0 / (head_size**0.5))
attn_op = xops.fmha.cutlass.FwOp() attn_op = xops.fmha.cutlass.FwOp()
if num_kv_heads != num_heads: if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of # project the key and value tensors to the desired number of
# heads. # heads.
# #
# see also: vllm/model_executor/layers/attention.py # see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1]) query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1]) num_queries_per_kv, key.shape[-1])
value = value[:, :, value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads, None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1]) num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0) query = query.unsqueeze(0)
key = key.unsqueeze(0) key = key.unsqueeze(0)
value = value.unsqueeze(0) value = value.unsqueeze(0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens) query_lens, seq_lens)
if sliding_window > 0: if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright( attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window) sliding_window)
output_ref = xops.memory_efficient_attention_forward( output_ref = xops.memory_efficient_attention_forward(
query, query,
key, key,
value, value,
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, p=0.0,
scale=scale, scale=scale,
op=attn_op, op=attn_op,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
start_time = time.time() start_time = time.time()
output_ref = xops.memory_efficient_attention_forward( output_ref = xops.memory_efficient_attention_forward(
query, query,
key, key,
value, value,
attn_bias=attn_bias, attn_bias=attn_bias,
p=0.0, p=0.0,
scale=scale, scale=scale,
op=attn_op, op=attn_op,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.reshape(output.shape) output_ref = output_ref.reshape(output.shape)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
...@@ -373,78 +375,80 @@ def test_contexted_kv_attention_alibi( ...@@ -373,78 +375,80 @@ def test_contexted_kv_attention_alibi(
torch.cuda.synchronize() torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
if not is_hip():
scale = float(1.0 / (head_size**0.5))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens),
num_heads,
head_size,
dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0
query_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([
torch.zeros(
seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...]
],
dim=0)
seq_start += seq_len
query_start += query_len
query = query_pad
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
# NOTE(DefTruth): In order to reuse _make_alibi_bias function, attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
# we have to pad query tensor before MQA/GQA expanding. output_ref = torch.empty_like(output)
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens),
num_heads,
head_size,
dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0 seq_start = 0
query_start = 0 query_start = 0
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len seq_end = seq_start + seq_len
query_end = query_start + query_len query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([ out = xops.memory_efficient_attention_forward(query[:,
torch.zeros( seq_start:seq_end],
seq_len - query_len, num_heads, head_size, dtype=dtype), key[:,
query[query_start:query_end, ...] seq_start:seq_end],
], value[:,
dim=0) seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len seq_start += seq_len
query_start += query_len query_start += query_len
query = query_pad torch.cuda.synchronize()
end_time = time.time()
if num_kv_heads != num_heads: print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len
query_start += query_len
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
import contextlib import contextlib
import functools import functools
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
try:
from lmslim import quant_ops
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq model.\n")
try: try:
import vllm._C import vllm._C
except ImportError as e: except ImportError as e:
...@@ -56,6 +60,18 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: ...@@ -56,6 +60,18 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x) torch.ops._C.gelu_tanh_and_mul(out, x)
def silu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul_opt(out, x)
def gelu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul_opt(out, x)
def gelu_tanh_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul_opt(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
...@@ -125,6 +141,65 @@ def paged_attention_v2( ...@@ -125,6 +141,65 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt)
def paged_attention_v1_opt(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v2_opt(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
# pos encoding ops # pos encoding ops
def rotary_embedding( def rotary_embedding(
positions: torch.Tensor, positions: torch.Tensor,
...@@ -157,10 +232,31 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, ...@@ -157,10 +232,31 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None: weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# layer norm ops (opt)
def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# quantization ops # quantization ops
# awq # awq
def GetAWQShareWorkspaceSize()->int:
return quant_ops.GetAWQShareWorkspaceSize()
def GetAWQShareWorkspace()->torch.Tensor:
return quant_ops.GetAWQShareWorkspace()
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int, zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor: thy: int) -> torch.Tensor:
...@@ -168,24 +264,57 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, ...@@ -168,24 +264,57 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
thx, thy) thx, thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, # def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: # scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
zeros_and_scales:torch.Tensor,
m:int,n:int,k:int,
group_size:int,padding_group:int,splikspace:torch.Tensor,
splikspacesize:int) -> torch.Tensor:
return quant_ops.awq_gemm(input,
weight,
zeros_and_scales,
m,
n,
k,
group_size,
padding_group,
splikspace,
splikspacesize)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
return quant_ops.convert_s4(qw,qz,s,group_size)
def sz_permute(sz:torch.Tensor)-> torch.Tensor:
return quant_ops.sz_permute(sz)
def dequant_w4_gemm_colmajor(qweight:torch.Tensor,
zeros_and_scale:torch.Tensor,
k:int,
n:int,
group_size:int
)->torch.Tensor:
return quant_ops.dequant_w4_gemm_colmajor(qweight,zeros_and_scale,k,n,group_size)
# gptq # gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool, b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor: bit: int) -> torch.Tensor:
return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, return quant_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit) b_g_idx, use_exllama, bit)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None: bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) quant_ops.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm # squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
......
...@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen >= 8000
self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 if self.use_flash_attn_auto:
triton_attention) from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
self.attn_func = triton_attention flash_attn_varlen_func)
logger.debug("Using Triton FA in ROCmBackend") self.attn_func_triton = flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func_ck = flash_attn_varlen_func
logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA")
else:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func # triton_attention
logger.debug("Using Triton FA in ROCmBackend")
else: else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn # if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either # either
...@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
out, _ = self.attn_func( if self.use_flash_attn_auto:
query, if prefill_meta.max_prefill_seq_len >= 8000:
key, out = self.attn_func_triton(
value, q=query,
None, k=key,
prefill_meta.seq_start_loc, v=value,
prefill_meta.seq_start_loc, cu_seqlens_q=prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len, cu_seqlens_k=prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len, max_seqlens_q=prefill_meta.max_prefill_seq_len,
True, max_seqlens_k=prefill_meta.max_prefill_seq_len,
self.scale, softmax_scale=self.scale,
causal=True,
)
else:
out = self.attn_func_ck(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
# out, _ = self.attn_func(
# query,
# key,
# value,
# None,
# prefill_meta.seq_start_loc,
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
# )
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
) )
elif self.use_naive_attn: elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround. # Interleave for MQA workaround.
...@@ -439,4 +491,4 @@ def _sdpa_attention( ...@@ -439,4 +491,4 @@ def _sdpa_attention(
output[start:end, :, :] = sub_out output[start:end, :, :] = sub_out
start = end start = end
return output return output
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
...@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0": ...@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0":
sliding_window=None): sliding_window=None):
cap = torch.cuda.get_device_capability() cap = torch.cuda.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64 BLOCK = 32 if cap[0] >= 8 else 32
# shape constraints # shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv assert Lq == Lk and Lk == Lv
...@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0": ...@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0":
if sliding_window is None or sliding_window <= 0: if sliding_window is None or sliding_window <= 0:
sliding_window = 0 sliding_window = 0
num_warps = 8 if Lk <= 64 else 8 num_warps = 8 if Lk <= 64 else 4
if alibi_slopes is not None: if alibi_slopes is not None:
_fwd_kernel_alibi[grid]( _fwd_kernel_alibi[grid](
q, q,
......
This diff is collapsed.
...@@ -173,7 +173,7 @@ class ModelConfig: ...@@ -173,7 +173,7 @@ class ModelConfig:
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"] rocm_supported_quantization = ["gptq", "squeezellm","awq"]
if self.quantization is not None: if self.quantization is not None:
self.quantization = self.quantization.lower() self.quantization = self.quantization.lower()
...@@ -279,6 +279,12 @@ class ModelConfig: ...@@ -279,6 +279,12 @@ class ModelConfig:
return self.hf_text_config.hidden_size return self.hf_text_config.hidden_size
def get_head_size(self) -> int: def get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
if hasattr(self.hf_text_config, "head_dim"): if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models. # FIXME(woosuk): This may not be true for all models.
......
...@@ -315,4 +315,4 @@ class CustomAllreduce: ...@@ -315,4 +315,4 @@ class CustomAllreduce:
self._ptr = 0 self._ptr = 0
def __del__(self): def __del__(self):
self.close() self.close()
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
...@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread): ...@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)""" """Handle results from all workers (in background thread)"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(daemon=True) super().__init__(daemon=False)
# super().__init__(daemon=True)
self.result_queue = mp.Queue() self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
...@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread): ...@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
def __init__(self, workers: List['ProcessWorkerWrapper'], def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler): result_handler: ResultHandler):
super().__init__(daemon=True) super().__init__(daemon=False)
# super().__init__(daemon=True)
self.workers = workers self.workers = workers
self.result_handler = result_handler self.result_handler = result_handler
self._close = False self._close = False
...@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread): ...@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
self._close = True self._close = True
# Kill / cleanup all workers # Kill / cleanup all workers
for worker in self.workers: # for worker in self.workers:
process = worker.process # process = worker.process
if process.sentinel in dead_sentinels: # if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S) # process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0: # if process.exitcode is not None and process.exitcode != 0:
logger.error("Worker %s pid %s died, exit code: %s", # logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode) # process.name, process.pid, process.exitcode)
if not sys.is_finalizing():
# Kill / cleanup all workers
died_count = 0
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
died_count += 1
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid,
process.exitcode)
if died_count < len(self.workers):
logger.info(
"Killing remaining local vLLM worker processes")
# Cleanup any remaining workers # Cleanup any remaining workers
logger.info("Killing local vLLM worker processes") # logger.info("Killing local vLLM worker processes")
for worker in self.workers: for worker in self.workers:
worker.kill_worker() worker.kill_worker()
# Must be done after worker task queues are all closed # Must be done after worker task queues are all closed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment