Commit 7462218e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1'

parents 6ccd3f47 1cec5e62
......@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
......@@ -68,6 +96,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
// Activation function used in SwiGLU. (opt)
ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul_opt);
// Activation function used in GeGLU with `none` approximation. (opt)
ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul_opt);
// Activation function used in GeGLU with `tanh` approximation. (opt)
ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul_opt);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
......@@ -81,6 +121,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops.def(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);
// In-place fused Add and RMS Normalization. (opt)
ops.def(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
......@@ -89,6 +141,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Rotary embedding TGI for TGI
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()");
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops.def(
......@@ -99,6 +160,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()");
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
......
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
if __name__ == '__main__':
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
# Create an LLM.
llm = LLM(model="facebook/opt-125m",trust_remote_code=True, dtype="float16", enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, distributed_executor_backend="ray", dtype="float16",trust_remote_code=True, enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
from vllm.sampling_params import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
import asyncio
from vllm.utils import FlexibleArgumentParser
from transformers import AutoTokenizer
import logging
import argparse
import sys
vllm_logger = logging.getLogger("vllm")
vllm_logger.setLevel(logging.WARNING)
class FlexibleArgumentParser(argparse.ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
def parse_args(self, args=None, namespace=None):
if args is None:
args = sys.argv[1:]
# Convert underscores to dashes and vice versa in argument names
processed_args = []
for arg in args:
if arg.startswith('--'):
if '=' in arg:
key, value = arg.split('=', 1)
key = '--' + key[len('--'):].replace('_', '-')
processed_args.append(f'{key}={value}')
else:
processed_args.append('--' +
arg[len('--'):].replace('_', '-'))
else:
processed_args.append(arg)
return super().parse_args(processed_args, namespace)
parser = FlexibleArgumentParser()
parser.add_argument('--template', type=str, help="Path to template")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
# chat = [
# {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
tokenizer = AutoTokenizer.from_pretrained(args.model)
try:
f = open(args.template,'r')
tokenizer.chat_template = f.read()
except Exception as e:
print('except:',e)
finally:
f.close()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
model_name = args.model.split("/")[-1] if args.model.split("/")[-1] !="" else args.model.split("/")[-2]
print(f"欢迎使用{model_name}模型,输入内容即可进行对话,stop 终止程序")
def build_prompt(history):
prompt = ""
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\n{model_name}:{response}"
return prompt
history = []
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
history.append({"role": "user", "content": query})
new_query = tokenizer.apply_chat_template(history, tokenize=False)
example_input = {
"prompt": new_query,
"stream": False,
"temperature": 0.0,
"request_id": 0,
}
results_generator = engine.generate(
example_input["prompt"],
SamplingParams(temperature=example_input["temperature"], max_tokens=100),
example_input["request_id"]
)
start = 0
end = 0
response = ""
async def process_results():
async for output in results_generator:
global end
global start
global response
print(output.outputs[0].text[start:], end="", flush=True)
length = len(output.outputs[0].text)
start = length
response = output.outputs[0].text
asyncio.run(process_results())
history.append({"role": "assistant", "content": response})
print()
{% if messages[0]['role'] == 'system' %}
{% set system_message = '<<SYS>>\n' + messages[0]['content'] | trim + '\n<</SYS>>\n\n' %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 %}
{% set content = system_message + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content | trim + ' ' + eos_token }}
{% endif %}
{% endfor %}
\ No newline at end of file
......@@ -2,6 +2,5 @@
-r requirements-common.txt
# Dependencies for AMD GPUs
ray == 2.9.1
# ray >= 2.10.0
ray >= 2.10.0
pytest-asyncio
......@@ -18,6 +18,9 @@ from typing import Optional, Union
import subprocess
from pathlib import Path
add_git_version = False
if int(os.environ.get('ADD_GIT_VERSION', '0')) == 1:
add_git_version = True
def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path)
......@@ -317,33 +320,23 @@ def find_version(filepath: str) -> str:
raise RuntimeError("Unable to find version string.")
def get_abi():
try:
command = "echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stdout.strip()
abi = "abi" + output.split(" ")[-1]
return abi
except Exception:
return 'abiUnknown'
def get_sha(root: Union[str, Path]) -> str:
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=root).decode('ascii').strip()
except Exception:
return 'Unknown'
def get_version_add(sha: Optional[str] = None) -> str:
vllm_root = os.path.dirname(os.path.abspath(__file__))
add_version_path = os.path.join(os.path.join(vllm_root, "vllm"), "version.py")
if sha != 'Unknown':
if sha is None:
sha = get_sha(vllm_root)
version = 'das1.1.git' + sha[:7]
# abi version
version += "." + get_abi()
if add_git_version:
if sha != 'Unknown':
if sha is None:
sha = get_sha(vllm_root)
version = 'das.opt1' + sha[:7]
else:
version = 'das.opt1'
# dtk version
if os.getenv("ROCM_PATH"):
......@@ -351,12 +344,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version")
with open(rocm_version_path, 'r',encoding='utf-8') as file:
lines = file.readlines()
rocm_version=lines[0][:-2].replace(".", "")
rocm_version=lines[0].replace(".", "")
version += ".dtk" + rocm_version
# torch version
version += ".torch" + torch.__version__[:5]
with open(add_version_path, encoding="utf-8",mode="w") as file:
file.write("__version__='0.5.0.post1'\n")
file.write("__dcu_version__='0.5.0.post1+{}'\n".format(version))
......
......@@ -67,7 +67,8 @@ def test_chunked_prefill_recompute(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
# @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption(
caplog_vllm,
......@@ -118,7 +119,8 @@ def test_preemption(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
# @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
......@@ -176,7 +178,8 @@ def test_swap(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
# @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap_infeasible(
......@@ -220,7 +223,8 @@ def test_swap_infeasible(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
# @pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption_infeasible(
vllm_runner,
......
......@@ -359,6 +359,8 @@ def test_multi_query_kv_attention(
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = output.squeeze(0)
......
......@@ -9,6 +9,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import is_hip
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
......@@ -158,57 +159,58 @@ def test_contexted_kv_attention(
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
if not is_hip():
scale = float(1.0 / (head_size**0.5))
attn_op = xops.fmha.cutlass.FwOp()
attn_op = xops.fmha.cutlass.FwOp()
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window)
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
start_time = time.time()
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window)
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
start_time = time.time()
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
......@@ -373,78 +375,80 @@ def test_contexted_kv_attention_alibi(
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
if not is_hip():
scale = float(1.0 / (head_size**0.5))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens),
num_heads,
head_size,
dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0
query_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([
torch.zeros(
seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...]
],
dim=0)
seq_start += seq_len
query_start += query_len
query = query_pad
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens),
num_heads,
head_size,
dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([
torch.zeros(
seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...]
],
dim=0)
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len
query_start += query_len
query = query_pad
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len
query_start += query_len
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
import contextlib
import functools
from typing import List, Optional, Tuple, Type
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from lmslim import quant_ops
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq model.\n")
try:
import vllm._C
except ImportError as e:
......@@ -56,6 +60,18 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x)
def silu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul_opt(out, x)
def gelu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul_opt(out, x)
def gelu_tanh_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul_opt(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
......@@ -125,6 +141,65 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt)
def paged_attention_v1_opt(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v2_opt(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
......@@ -157,10 +232,31 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# layer norm ops (opt)
def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# quantization ops
# awq
def GetAWQShareWorkspaceSize()->int:
return quant_ops.GetAWQShareWorkspaceSize()
def GetAWQShareWorkspace()->torch.Tensor:
return quant_ops.GetAWQShareWorkspace()
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
......@@ -168,24 +264,57 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
thx, thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
# scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
# return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
zeros_and_scales:torch.Tensor,
m:int,n:int,k:int,
group_size:int,padding_group:int,splikspace:torch.Tensor,
splikspacesize:int) -> torch.Tensor:
return quant_ops.awq_gemm(input,
weight,
zeros_and_scales,
m,
n,
k,
group_size,
padding_group,
splikspace,
splikspacesize)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
return quant_ops.convert_s4(qw,qz,s,group_size)
def sz_permute(sz:torch.Tensor)-> torch.Tensor:
return quant_ops.sz_permute(sz)
def dequant_w4_gemm_colmajor(qweight:torch.Tensor,
zeros_and_scale:torch.Tensor,
k:int,
n:int,
group_size:int
)->torch.Tensor:
return quant_ops.dequant_w4_gemm_colmajor(qweight,zeros_and_scale,k,n,group_size)
# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
return quant_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
quant_ops.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
......
......@@ -228,11 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen >= 8000
self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend")
if self.use_flash_attn_auto:
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func_triton = flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func_ck = flash_attn_varlen_func
logger.debug("When SEQ_LEN > 8000, Use Triton FA in ROCmBackend, otherwise Use CK FA")
else:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func # triton_attention
logger.debug("Using Triton FA in ROCmBackend")
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
......@@ -325,18 +339,56 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if self.use_triton_flash_attn:
out, _ = self.attn_func(
query,
key,
value,
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len,
prefill_meta.max_prefill_seq_len,
True,
self.scale,
if self.use_flash_attn_auto:
if prefill_meta.max_prefill_seq_len >= 8000:
out = self.attn_func_triton(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
out = self.attn_func_ck(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
else:
# out, _ = self.attn_func(
# query,
# key,
# value,
# None,
# prefill_meta.seq_start_loc,
# prefill_meta.seq_start_loc,
# prefill_meta.max_prefill_seq_len,
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
# )
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
......@@ -439,4 +491,4 @@ def _sdpa_attention(
output[start:end, :, :] = sub_out
start = end
return output
return output
\ No newline at end of file
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import argparse
import pytest
import random
import sys
import torch
import triton
import triton.language as tl
torch_dtype:tl.constexpr = torch.float16
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
if TORCH_HAS_FP8E5:
torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
class MetaData():
cu_seqlens_q = None
cu_seqlens_k = None
max_seqlens_q = 0
max_seqlens_k = 0
bias = None
alibi_slopes = None
causal = False
num_contexts = 0
varlen = False
dropout_p, return_encoded_softmax = 0.0, False
def __init__(self, sm_scale=1.0, causal=False, dropout_p=0.0, return_encoded_softmax=False):
self.sm_scale = sm_scale
self.causal = causal
self.dropout_p = dropout_p
self.return_encoded_softmax = return_encoded_softmax
def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
self.varlen = True
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_k = cu_seqlens_k
# Without "varlen", there should still be one sequence.
assert len(cu_seqlens_q) >= 2
assert len(cu_seqlens_q) == len(cu_seqlens_k)
self.num_contexts = len(cu_seqlens_q) - 1
for i in range (0, self.num_contexts):
self.max_seqlens_q = max(cu_seqlens_q[i+1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q)
self.max_seqlens_k = max(cu_seqlens_k[i+1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k)
def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k):
assert bias.is_cuda
assert bias.dim() == 4
assert bias.shape[0] == 1
assert bias.shape[2:] == (seqlen_q, seqlen_k)
self.bias = bias
def need_alibi(self, alibi_slopes, batch, nheads):
assert alibi_slopes.is_cuda
assert alibi_slopes.dim() == 2
assert alibi_slopes.shape[0] == batch
assert alibi_slopes.shape[1] == nheads
self.alibi_slopes = alibi_slopes
def need_causal(self):
self.causal = True
def need_dropout(dropout_p, return_encoded_softmax):
self.dropout_p = dropout_p
self.return_encoded_softmax = return_encoded_softmax
def check_args(self, q, k, v, o):
assert q.dim() == k.dim() and q.dim() == v.dim()
if self.varlen:
assert q.dim() == 3
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
assert self.cu_seqlens_q is not None
assert self.cu_seqlens_k is not None
assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
# TODO: Remove once bias is supported with varlen
assert self.bias == None
# TODO:Remove once dropout is supported with varlen
assert self.dropout_p == 0.0
assert not self.return_encoded_softmax
else:
assert q.dim() == 4
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
@triton.jit
def cdiv_fn(x,y):
return (x + y - 1) // y
@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep
@triton.jit
def load_fn(block_ptr, first, second, pad):
if first and second:
tensor = tl.load(block_ptr, boundary_check=(0,1), padding_option=pad)
elif first:
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit
def print_gpu(prefix, val=None):
if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)):
if val is not None:
tl.device_print(prefix, val)
else:
tl.device_print(prefix)
@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose = False):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
# for casual mask we want something like this where (1 is kept and 0 is masked)
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
# [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block = offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:]
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
if transpose:
return alibi_block.T
else:
return alibi_block
@triton.jit
def _attn_fwd_inner(
acc, l_i, m_i, q,
K_block_ptr, V_block_ptr,
start_m,
actual_seqlen_k,
actual_seqlen_q,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min, block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
alibi_slope,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr
):
# loop over k, v, and update accumulator
for start_n in range (block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero")
if PRE_LOAD_V:
v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero")
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS:
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
# last step might get wasted but that is okay. check if this masking works For
# that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
size_n = start_n + OFFS_N[None,:]
mask = size_n < boundary_m[:,None]
qk = tl.where(mask, qk, float("-inf"))
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero")
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += (bias * 1.44269504089)
if alibi_slope is not None:
# Compute the global position of each token within the sequence
global_m_positions = start_m*BLOCK_M + tl.arange(0, BLOCK_M)
global_n_positions = start_n + tl.arange(0, BLOCK_N)
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, global_n_positions)
qk += (alibi_block * 1.44269504089) # scale factor of log2(e)
# softmax
m_ij = tl.maximum(m_i, tl.max(qk,1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k)
if RETURN_ENCODED_SOFTMAX:
tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty))
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty))
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero")
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N))
return acc, l_i, m_i
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 0, 'PRE_LOAD_V': True}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=2, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 16, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8),
# TODO: This config fails with head_size not pow2 with data mismatches. Check why.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
# triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
# use_cuda_graph=True,
)
@triton.jit
def attn_fwd(
Q, K, V, bias, sm_scale, L, Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
stride_bz, stride_bh, stride_bm, stride_bn,
stride_az, stride_ah,
cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, encoded_softmax,
alibi_slopes,
HQ: tl.constexpr, HK:tl.constexpr,
ACTUAL_BLOCK_DMODEL:tl.constexpr,
MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr,
VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
USE_ALIBI: tl.constexpr,
BATCH_SIZE: tl.constexpr,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if (IS_CAUSAL):
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q,
BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is part of
# the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# We store inf to LSE, not -inf because in the bwd pass, we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks.
l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here too?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
if GROUP_SIZE != 1:
off_h_k = off_h_q // GROUP_SIZE
else:
off_h_k = off_h_q
need_padding = False
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
need_padding = True
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
need_padding = True
n_extra_tokens = seqlen_k % BLOCK_N
PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
# Compute pointers for all the tensors used in this kernel.
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
if BIAS_TYPE != 0:
b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs
bias_ptr = tl.make_block_ptr(
base=bias + b_offset,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
if USE_ALIBI:
a_offset = off_z * stride_az + off_h_q * stride_ah
alibi_slope = tl.load(alibi_slopes + a_offset)
else:
alibi_slope = None
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout. In
# this case, we return an invalid pointer so indicate the mask is not valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.make_block_ptr(
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
shape=(seqlen_q, seqlen_k),
strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0)
)
else:
encoded_softmax_block_ptr = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
# In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its actual
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m, seqlen_k, seqlen_q,
dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_max, 0, 0, 0, bias_ptr, alibi_slope,
# IS_CAUSAL, ....
False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD
)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are full / not masked.
if (masked_blocks > 0):
if IS_CAUSAL:
offs_n_causal = offs_n + (seqlen_q - seqlen_k)
else:
offs_n_causal = 0
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks*BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks*BLOCK_N, 0))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks*BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, n_full_blocks))
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m, seqlen_k, seqlen_q,
dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, alibi_slope,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD
)
# epilogue
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL:
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full((BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.
# This is only true for the last M block. For others, overflow_size will be -ve
overflow_size = end_m_idx - seqlen_q
if overflow_size > 0:
boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# This is a > check because mask being 0 blocks the store.
l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0,1))
@triton.jit
def _attn_bwd_preprocess(
Out, DO,
Delta,
stride_oz, stride_oh, stride_om, stride_on,
stride_doz, stride_doh, stride_dom, stride_don,
seqlen_q,
head_dim,
BLOCK_M: tl.constexpr,
D_HEAD: tl.constexpr,
):
# off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
# off_n = tl.arange(0, D_HEAD)
off_m = tl.program_id(0) * BLOCK_M
off_h = tl.program_id(1) # head index
off_z = tl.program_id(2) # batch index
num_h = tl.num_programs(1)
o_offset = off_h * stride_oh + off_z * stride_oz
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, head_dim),
strides=(stride_om, stride_on),
offsets=(off_m, 0),
block_shape=(BLOCK_M, D_HEAD),
order=(1, 0)
)
do_offset = off_h * stride_doh + off_z * stride_doz
DO_block_ptr = tl.make_block_ptr(
base=DO + do_offset,
shape=(seqlen_q, head_dim),
strides=(stride_dom, stride_don),
offsets=(off_m, 0),
block_shape=(BLOCK_M, D_HEAD),
order=(1, 0)
)
# load
# o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
# do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
o = tl.load(O_block_ptr, boundary_check=(0,1), padding_option="zero").to(tl.float32)
do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero").to(tl.float32)
# compute
delta = tl.sum(o * do, axis=1)
# write-back, shape (q.shape[0] * q.shape[1], q.shape[2])
off_zh = off_z * num_h + off_h * 1
# Check for OOB accesses
delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M)
overflow = off_m + BLOCK_M - seqlen_q
if overflow > 0:
boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32)
mask = boundary > tl.arange(0, BLOCK_M)
tl.store(delta_ptrs, delta, mask=mask)
else:
tl.store(delta_ptrs, delta)
@triton.jit
def _bwd_kernel_dk_dv(
dk, dv,
Q, k, v, sm_scale, alibi_slope,
DO,
M, D,
# shared by Q/K/V/DO.
stride_tok, stride_d,
H, N_CTX, BLOCK_M1: tl.constexpr,
BLOCK_N1: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
# Filled in by the wrapper.
start_n, start_m, num_steps,
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M1)
offs_n = start_n + tl.arange(0, BLOCK_N1)
offs_k = tl.arange(0, BLOCK_DMODEL)
QT_block_ptr = tl.make_block_ptr(
base=Q,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_d, stride_tok),
offsets=(0, start_m),
block_shape=(BLOCK_DMODEL, BLOCK_M1),
order=(0,1)
)
DO_block_ptr = tl.make_block_ptr(
base=DO,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_m, 0),
block_shape=(BLOCK_M1, BLOCK_DMODEL),
order=(1,0)
)
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
curr_m = start_m
step_m = BLOCK_M1
for blk_idx in range(num_steps):
qT = tl.load(QT_block_ptr)
# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m)
kqT = tl.dot(k, qT)
if alibi_slope is not None:
alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True)
kqT += alibi_block * 1.44269504089
pT = tl.math.exp2(kqT - m[None, :])
# Autoregressive masking.
if MASK:
mask = (offs_m[None, :] >= offs_n[:, None])
pT = tl.where(mask, pT, 0.0)
do = tl.load(DO_block_ptr)
# Compute dV.
ppT = pT
ppT = ppT.to(tl.float16)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(tl.float16)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))
DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))
return dk, dv
@triton.jit
def _bwd_kernel_dq(dq, q, K, V,
do, m, D, alibi_slope,
# shared by Q/K/V/DO.
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2: tl.constexpr,
BLOCK_N2: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
# Filled in by the wrapper.
start_m, start_n, num_steps,
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M2)
offs_n = start_n + tl.arange(0, BLOCK_N2)
offs_k = tl.arange(0, BLOCK_DMODEL)
KT_block_ptr = tl.make_block_ptr(
base=K,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_d, stride_tok),
offsets=(0, start_n),
block_shape=(BLOCK_DMODEL, BLOCK_N2),
order=(0, 1)
)
VT_block_ptr = tl.make_block_ptr(
base=V,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_d, stride_tok),
offsets=(0, start_n),
block_shape=(BLOCK_DMODEL, BLOCK_N2),
order=(0, 1)
)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
curr_n = start_n
step_n = BLOCK_N2
for blk_idx in range(num_steps):
kT = tl.load(KT_block_ptr)
qk = tl.dot(q, kT)
if alibi_slope is not None:
alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n)
qk += alibi_block * 1.44269504089
p = tl.math.exp2(qk - m)
# Autoregressive masking.
if MASK:
offs_n = curr_n + tl.arange(0, BLOCK_N2)
mask = (offs_m[:, None] >= offs_n[None, :])
p = tl.where(mask, p, 0.0)
# Compute dP and dS.
vT = tl.load(VT_block_ptr)
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - Di[:, None])
ds = ds.to(tl.float16)
# Compute dQ.0.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
# Increment pointers.
curr_n += step_n
KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))
VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
return dq
@triton.jit
def _attn_bwd(Q, K, V, sm_scale, alibi_slopes,
DO,
DQ, DK, DV,
M, D,
# shared by Q/K/V/DO.
stride_z, stride_h, stride_tok, stride_d,
# H = 16, N_CTX = 1024
H, N_CTX,
BLOCK_DMODEL: tl.constexpr,
BLOCK_M1: tl.constexpr,
BLOCK_N1: tl.constexpr,
BLOCK_M2: tl.constexpr,
BLOCK_N2: tl.constexpr,
BLK_SLICE_FACTOR: tl.constexpr,
USE_ALIBI: tl.constexpr):
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
bhid = tl.program_id(2)
off_chz = (bhid * N_CTX).to(tl.int64)
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
pid = tl.program_id(0)
# offset pointers for batch/head
Q += adj
K += adj
V += adj
DO += adj
DQ += adj
DK += adj
DV += adj
M += off_chz
D += off_chz
offs_k = tl.arange(0, BLOCK_DMODEL)
start_n = pid * BLOCK_N1
# This assignment is important. It is what allows us to pick the diagonal
# blocks. Later, when we want to do the lower triangular, we update start_m
# after the first dkdv call.
start_m = start_n
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_n, 0),
block_shape=(BLOCK_N1, BLOCK_DMODEL),
order=(1, 0),
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_n, 0),
block_shape=(BLOCK_N1, BLOCK_DMODEL),
order=(1, 0),
)
# load K and V: they stay in SRAM throughout the inner loop for dkdv.
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
if USE_ALIBI:
a_offset = bhid
alibi_slope = tl.load(alibi_slopes + a_offset)
else:
alibi_slope = None
# compute dK and dV for blocks close to the diagonal that need to be masked
num_steps = BLOCK_N1 // MASK_BLOCK_M1
dk, dv = _bwd_kernel_dk_dv(
dk, dv,
Q, k, v, sm_scale, alibi_slope,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=True
)
# compute dK and dV for blocks that don't need masking further from the diagonal
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
dk, dv = _bwd_kernel_dk_dv(
dk, dv,
Q, k, v, sm_scale, alibi_slope,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=False
)
DV_block_ptrs = tl.make_block_ptr(
base=DV,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_n, 0),
block_shape=(BLOCK_N1, BLOCK_DMODEL),
order=(1,0)
)
tl.store(DV_block_ptrs, dv.to(v.dtype))
# Write back dK.
dk *= sm_scale
DK_block_ptrs = tl.make_block_ptr(
base=DK,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_n, 0),
block_shape=(BLOCK_N1, BLOCK_DMODEL),
order=(1,0)
)
tl.store(DK_block_ptrs, dk.to(k.dtype))
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_m, 0),
block_shape=(BLOCK_M2, BLOCK_DMODEL),
order=(1, 0)
)
DO_block_ptr = tl.make_block_ptr(
base=DO,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_m, 0),
block_shape=(BLOCK_M2, BLOCK_DMODEL),
order=(1, 0)
)
q = tl.load(Q_block_ptr)
do = tl.load(DO_block_ptr)
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
m = tl.load(M + offs_m)
m = m[:, None]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _bwd_kernel_dq(dq, q, K, V,
do, m, D, alibi_slope,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
MASK=True
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _bwd_kernel_dq(dq, q, K, V,
do, m, D, alibi_slope,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * BLOCK_N2, num_steps,
MASK=False
)
# Write back dQ.
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_m, 0),
block_shape=(BLOCK_M2, BLOCK_DMODEL),
order=(1, 0)
)
dq *= LN2
tl.store(DQ_block_ptr, dq.to(q.dtype))
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, o, metadata):
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
if (metadata.bias is not None):
assert(metadata.bias.numel() < 2 ** 31)
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
import os
if os.environ.get("FLASH_ATTENTION_PRINT_PARAM", "0") == "1":
print(f"triton flash attention: {q.shape=}, {k.shape=}, {v.shape}, {o.shape=}")
print(f"triton flash attention: {q.stride()=}, {k.stride()=}, {v.stride()=}, {o.stride()=}")
print(f"triton flash attention: {metadata=}")
metadata.check_args(q, k, v, o)
if metadata.varlen:
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = metadata.num_contexts
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))
# Get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
grid = lambda META: (
triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']),
nheads_q,
batch
)
# encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
# to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
# to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
# only. This return holds no useful output aside from debugging.
if metadata.return_encoded_softmax:
encoded_softmax = torch.zeros((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, dtype=torch.float32)
else:
encoded_softmax = None
M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32)
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if metadata.bias is not None:
bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1),
metadata.bias.stride(2), metadata.bias.stride(3))
else:
bias_strides = (0,0,0,0)
if metadata.alibi_slopes is not None:
alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1))
else:
alibi_strides = (0, 0)
attn_fwd[grid](
q, k, v, metadata.bias, metadata.sm_scale, M, o,
*q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides,
metadata.cu_seqlens_q, metadata.cu_seqlens_k,
dropout_p=metadata.dropout_p,
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
alibi_slopes = metadata.alibi_slopes,
HQ=nheads_q, HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=metadata.max_seqlens_q,
MAX_SEQLENS_K=metadata.max_seqlens_k,
IS_CAUSAL=metadata.causal,
VARLEN=metadata.varlen,
BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if metadata.bias is None else 1,
USE_ALIBI=False if metadata.alibi_slopes is None else True,
ENABLE_DROPOUT=metadata.dropout_p > 0.0,
RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax,
BATCH_SIZE= q.shape[0]
)
if os.environ.get("FLASH_ATTENTION_PRINT_PARAM", "0") == "1":
best_config = attn_fwd.get_best_config()
print(f"{best_config.kwargs=}, {best_config.num_stages=}, {best_config.num_warps=}")
ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
ctx.sm_scale = metadata.sm_scale
ctx.BLOCK_DMODEL = head_size
ctx.causal = metadata.causal
ctx.alibi_slopes = metadata.alibi_slopes
ctx.dropout_p = metadata.dropout_p
ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = metadata.return_encoded_softmax
return o if not metadata.return_encoded_softmax else (o, encoded_softmax, ) # S_dmask
@staticmethod
def backward(ctx, do, *args):
if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
q, k, v, o, M = ctx.saved_tensors
import os
if os.environ.get("TRITON_FLASHATTN_DEBUG", "0") == "1":
print(f"triton flash attention: {q.shape=}, {k.shape=}, {v.shape}, {o.shape=}, {do.shape=}")
print(f"triton flash attention: {q.stride()=}, {k.stride()=}, {v.stride()=}, {o.stride()=}, {do.stride()}")
# assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
seqlen_q = q.shape[2]
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 1
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
assert N_CTX % PRE_BLOCK == 0
delta = torch.empty_like(M)
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
padded_head = (Lk != ctx.BLOCK_DMODEL)
grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0])
_attn_bwd_preprocess[grid_preprocess](
o, do, delta,
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
do.stride(0), do.stride(1), do.stride(2), do.stride(3),
seqlen_q,
head_dim=Lk,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
grid = lambda META: (
triton.cdiv(N_CTX, META['BLOCK_N1']),
1,
BATCH * N_HEAD
)
_attn_bwd[grid](
q, arg_k, v, ctx.sm_scale, ctx.alibi_slopes, do, dq, dk, dv,
M, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
N_HEAD, N_CTX,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,
USE_ALIBI= False if ctx.alibi_slopes is None else True,
)
return dq, dk, dv, None, None
attention = _attention.apply
# flash_attn wrapper
def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype):
torch.manual_seed(20)
# Initialize q, k, v
q = torch.randn((Z, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.max_seqlens_q = N_CTX_Q
input_metadata.max_seqlens_k = N_CTX_K
return q, k, v, input_metadata
def padding_bshd(t): # BSHD
batch, seqlen, nheads, dim = t.shape
t = torch.nn.functional.pad(t.reshape(batch, seqlen, nheads*dim), (0, 32), 'constant', 0)[:,:,:-32].reshape(batch, seqlen, nheads, dim) # pad: nheads*dim+32
# t = torch.nn.functional.pad(t.reshape(batch, seqlen, nheads, dim), (0, 32), 'constant', 0)[:,:,:,:-32] # pad: dim+32
return t
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False, padding_input=False):
if padding_input:
k, v = (padding_bshd(t) for t in (k, v))
q, k, v = (t.transpose(1, 2) for t in (q, k, v))
softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5
input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs)
input_metadata.max_seqlens_q = q.shape[2]
input_metadata.max_seqlens_k = k.shape[2]
return _attention.apply(q, k, v, None, input_metadata).transpose(1, 2)
def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False, padding_input=False):
k, v = kv[:, :, 0], kv[:, :, 1] # batch_size, seqlen, 2, nheads_k, headdim
if padding_input:
k, v = (padding_bshd(t) for t in (k, v)) # pad
q, k, v = (t.transpose(1, 2) for t in (q, k, v)) # trans bshd to bhsd
softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5
input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs)
input_metadata.max_seqlens_q = q.shape[2]
input_metadata.max_seqlens_k = k.shape[2]
return _attention.apply(q, k, v, None, input_metadata).transpose(1, 2) # trans bhsd to bshd
def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False, padding_input=False):
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
if padding_input:
k, v = (padding_bshd(t) for t in (k, v))
q, k, v = (t.transpose(1, 2) for t in (q, k, v))
softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5
input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs)
input_metadata.max_seqlens_q = q.shape[2]
input_metadata.max_seqlens_k = k.shape[2]
return _attention.apply(q, k, v, None, input_metadata).transpose(1, 2)
# varlen flash_attn
def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal):
torch.manual_seed(20)
# Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs
max_seqlens_q = N_CTX_Q // Z
max_seqlens_k = N_CTX_K // Z
seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32)
seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32)
max_seqlens_q = torch.max(seqlens_q).item()
max_seqlens_k = torch.max(seqlens_k).item()
# Calculate cumulative sequence lengths
cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)])
cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)])
cu_seqlens_q = cu_seqlens_q.to(device="cuda")
cu_seqlens_k = cu_seqlens_k.to(device="cuda")
# -1 because the last entry of cu_seqlens_q specifies the end of the last seq
num_ctxs = len(cu_seqlens_q) - 1
# Initialize q, k, v with variable lengths
total_q = cu_seqlens_q[-1].item()
total_k = cu_seqlens_k[-1].item()
q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
input_metadata.max_seqlens_q = max_seqlens_q
input_metadata.max_seqlens_k = max_seqlens_k
if causal:
input_metadata.need_causal()
return q, k, v, input_metadata
def padding_thd(t): # THD
total_seqlen, nheads, dim = t.shape
t = torch.nn.functional.pad(t.reshape(total_seqlen, nheads*dim), (0, 32), 'constant', 0)[:,:-32].reshape(total_seqlen, nheads, dim) # pad: nheads*dim+32
# t = torch.nn.functional.pad(t.reshape(total_seqlen, nheads, dim), (0, 32), 'constant', 0)[:,:-32] # pad: dim+32
return t
def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlens, dropout_p=0.0, softmax_scale=None,
causal=False, return_attn_probs=False, padding_input=False):
q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2] # total_seqlen, 3, nheads, dim
if padding_input:
k, v = (padding_thd(t) for t in (k, v)) # pad
softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5
input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs)
input_metadata.set_varlen_params(cu_seqlens, cu_seqlens)
input_metadata.max_seqlens_q = max_seqlens
input_metadata.max_seqlens_k = max_seqlens
return _attention.apply(q, k, v, None, input_metadata)
def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k,
dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False, padding_input=False):
k, v = kv[:, 0], kv[:, 1] # total_seqlen, 2, nheads, dim
if padding_input:
k, v = (padding_thd(t) for t in (k, v))
softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5
input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
input_metadata.max_seqlens_q = max_seqlens_q
input_metadata.max_seqlens_k = max_seqlens_k
return _attention.apply(q, k, v, None, input_metadata)
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k,
dropout_p=0.0, softmax_scale=None, causal=False,
return_attn_probs=False, padding_input=False):
if padding_input:
k, v = (padding_thd(t) for t in (k, v))
softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5
input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
input_metadata.max_seqlens_q = max_seqlens_q
input_metadata.max_seqlens_k = max_seqlens_k
return _attention.apply(q, k, v, None, input_metadata)
# legacy interface
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlens, dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False, padding_input=False):
return flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlens, dropout_p, softmax_scale,
causal, return_attn_probs, padding_input)
def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k,
dropout_p, softmax_scale=None, causal=False, return_attn_probs=False, padding_input=False):
return flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k,
dropout_p, softmax_scale, causal, return_attn_probs, padding_input)
def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k,
dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False, padding_input=False):
return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k,
dropout_p, softmax_scale, causal, return_attn_probs, padding_input)
......@@ -5,6 +5,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd
import vllm.envs as envs
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
......@@ -122,26 +123,53 @@ class PagedAttention:
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
......@@ -156,29 +184,61 @@ class PagedAttention:
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V2 SIZE:")
print(f"exp_sums.shape = {exp_sums.shape}, max_logits.shape = {max_logits.shape}, tmp_output.shape = {tmp_output.shape}")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
return output
@staticmethod
......
......@@ -684,7 +684,7 @@ if triton.__version__ >= "2.1.0":
sliding_window=None):
cap = torch.cuda.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
BLOCK = 32 if cap[0] >= 8 else 32
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
......@@ -701,7 +701,7 @@ if triton.__version__ >= "2.1.0":
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
num_warps = 8 if Lk <= 64 else 8
num_warps = 8 if Lk <= 64 else 4
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
......
"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple
import numpy as np
import torch
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
return filtered_dataset
def run_vllm(
warmup_requests: List[Tuple[str, int, int]],
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
)
# Add the requests to the engine.
prompts = []
sampling_params = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
# warmup
warmup_prompts = []
warmup_sampling_params = []
for prompt, _, output_len in warmup_requests:
warmup_prompts.append(prompt)
warmup_sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True)
# dummy_prompt_token_ids = np.random.randint(10000,
# size=(args.num_prompts,
# args.input_len))
# dummy_inputs: List[PromptStrictInputs] = [{
# "prompt_token_ids": batch
# } for batch in dummy_prompt_token_ids.tolist()]
# def run_to_completion():
# llm.generate(dummy_inputs,
# sampling_params=sampling_params,
# use_tqdm=False)
# print("Warming up...")
# for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
# run_to_completion()
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
return end - start
def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
llm = llm.cuda()
pbar = tqdm(total=len(requests))
start = time.perf_counter()
batch: List[str] = []
max_prompt_len = 0
max_output_len = 0
for i in range(len(requests)):
prompt, prompt_len, output_len = requests[i]
# Add the prompt to the batch.
batch.append(prompt)
max_prompt_len = max(max_prompt_len, prompt_len)
max_output_len = max(max_output_len, output_len)
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
_, next_prompt_len, next_output_len = requests[i + 1]
if (max(max_prompt_len, next_prompt_len) +
max(max_output_len, next_output_len)) <= 2048:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt",
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
use_cache=True,
max_new_tokens=max_output_len,
)
# Include the decoding time.
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
pbar.update(len(batch))
# Clear the batch.
batch = []
max_prompt_len = 0
max_output_len = 0
end = time.perf_counter()
return end - start
def run_mii(
requests: List[Tuple[str, int, int]],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
start = time.perf_counter()
llm.generate(prompts, max_new_tokens=output_len)
end = time.perf_counter()
client = client(model)
client.terminate_server()
return end - start
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
if args.dataset is None:
# Synthesize a prompt with the given input length.
warmup_prompt = "hi" * 10
warmup_requests = [(warmup_prompt, 10, 10)
for _ in range(1)]
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
if args.backend == "vllm":
elapsed_time = run_vllm(
warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.download_dir)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
if args.dataset is None:
total_out_tokens = args.output_len * args.num_prompts
else:
total_out_tokens = sum(output_len for _, _, output_len in requests)
print(f"Latency: {elapsed_time:.2f} s")
print(f"All Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
print(f"Generate Throughput: {total_out_tokens / elapsed_time:.2f} tokens/s")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
default="vllm")
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument('--num-iters-warmup',
type=int,
default=1,
help='Number of iterations to run for warmup.')
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--max-model-len',
type=int,
default=None,
help='Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir',
type=str,
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.dataset is None:
assert args.input_len is not None
assert args.output_len is not None
else:
assert args.input_len is None
if args.backend == "vllm":
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
elif args.backend == "hf":
if args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
elif args.backend == "mii":
if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
main(args)
\ No newline at end of file
......@@ -173,7 +173,7 @@ class ModelConfig:
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"]
rocm_supported_quantization = ["gptq", "squeezellm","awq"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
......@@ -279,6 +279,12 @@ class ModelConfig:
return self.hf_text_config.hidden_size
def get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models.
......
......@@ -315,4 +315,4 @@ class CustomAllreduce:
self._ptr = 0
def __del__(self):
self.close()
self.close()
\ No newline at end of file
......@@ -232,76 +232,91 @@ class LLMEngine:
load_config=load_config,
)
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(model_config.dtype),
"tensor_parallel_size":
parallel_config.tensor_parallel_size,
"block_size":
cache_config.block_size,
"gpu_memory_utilization":
cache_config.gpu_memory_utilization,
# Quantization
"quantization":
model_config.quantization,
"kv_cache_dtype":
cache_config.cache_dtype,
# Feature flags
"enable_lora":
bool(lora_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
"enforce_eager":
model_config.enforce_eager,
"disable_custom_all_reduce":
parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
self.get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq,
),
))
init_success = False
try:
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(model_config.dtype),
"tensor_parallel_size":
parallel_config.tensor_parallel_size,
"block_size":
cache_config.block_size,
"gpu_memory_utilization":
cache_config.gpu_memory_utilization,
# Quantization
"quantization":
model_config.quantization,
"kv_cache_dtype":
cache_config.cache_dtype,
# Feature flags
"enable_lora":
bool(lora_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
"enforce_eager":
model_config.enforce_eager,
"disable_custom_all_reduce":
parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
tokenizer_group = self.get_tokenizer_group()
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return tokenizer_group.get_lora_tokenizer(
sequence.lora_request)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
),
))
init_success = True
finally:
if not init_success:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self.model_executor.shutdown()
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
......@@ -393,10 +408,10 @@ class LLMEngine:
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
# def get_tokenizer_for_seq(self,
# sequence: Sequence) -> "PreTrainedTokenizer":
# return self.get_tokenizer_group().get_lora_tokenizer(
# sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
......@@ -785,7 +800,8 @@ class LLMEngine:
# Log stats.
self.do_log_stats(scheduler_outputs, output)
if not request_outputs:
# if not request_outputs:
if not self.has_unfinished_requests():
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
......
......@@ -9,6 +9,9 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
......@@ -27,7 +30,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
......@@ -131,7 +134,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
("true", "1")),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")),
# local rank of the process in the distributed setting, used to determine
......@@ -145,7 +163,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# timeout for each iteration in the engine
"VLLM_ENGINE_ITERATION_TIMEOUT_S":
lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")),
lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "120")),
# API key for VLLM API server
"VLLM_API_KEY":
......@@ -214,15 +232,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_WORKER_MULTIPROC_METHOD":
lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
"VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),
# Timeout for fetching images when serving multimodal models
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH":
lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
}
# end-env-vars-definition
......@@ -236,4 +252,4 @@ def __getattr__(name):
def __dir__():
return list(environment_variables.keys())
return list(environment_variables.keys())
\ No newline at end of file
......@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
def __init__(self) -> None:
super().__init__(daemon=True)
super().__init__(daemon=False)
# super().__init__(daemon=True)
self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
......@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler):
super().__init__(daemon=True)
super().__init__(daemon=False)
# super().__init__(daemon=True)
self.workers = workers
self.result_handler = result_handler
self._close = False
......@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
self._close = True
# Kill / cleanup all workers
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode)
# for worker in self.workers:
# process = worker.process
# if process.sentinel in dead_sentinels:
# process.join(JOIN_TIMEOUT_S)
# if process.exitcode is not None and process.exitcode != 0:
# logger.error("Worker %s pid %s died, exit code: %s",
# process.name, process.pid, process.exitcode)
if not sys.is_finalizing():
# Kill / cleanup all workers
died_count = 0
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
died_count += 1
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid,
process.exitcode)
if died_count < len(self.workers):
logger.info(
"Killing remaining local vLLM worker processes")
# Cleanup any remaining workers
logger.info("Killing local vLLM worker processes")
# logger.info("Killing local vLLM worker processes")
for worker in self.workers:
worker.kill_worker()
# Must be done after worker task queues are all closed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment