Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
import argparse
import copy
import json
import math
import os
from pathlib import Path
from typing import Any, List, Optional, Tuple
import matplotlib.pyplot as plt
import pandas as pd
## JSON parsing utils ####
def largest_dist_from_leaf(node: dict, depth: int = 0):
    """Return the depth of the deepest leaf reachable from ``node``.

    Args:
        node: A tree node; its ``"children"`` key holds the child nodes.
        depth: Depth already accumulated for ``node`` itself.

    Returns:
        ``depth`` when ``node`` has no children, otherwise the largest value
        obtained by recursing into each child with ``depth + 1``.
    """
    children = node["children"]
    if not children:
        # A leaf contributes exactly the depth accumulated so far.
        return depth
    return max(
        largest_dist_from_leaf(child, depth + 1) for child in children)
def get_entries_at_depth(depth: int,
                         entries_and_traces: List[Tuple[Any, Any]],
                         node: dict,
                         curr_depth: int = 0,
                         trace=()):
    """Collect the entries sitting ``abs(depth) - 1`` levels above the leaves.

    Results are appended to ``entries_and_traces`` as ``(entry, trace)``
    pairs, where ``trace`` is a tuple of ancestor names ordered from the
    nearest ancestor outwards.

    Args:
        depth: Must be -1 or -2 (kernel or module level query).
        entries_and_traces: Output list, extended in place.
        node: Current tree node (``"entry"`` and ``"children"`` keys).
        curr_depth: Depth of ``node`` below the root of the traversal.
        trace: Names of the ancestors of ``node``.
    """
    # assert that the query is at kernel or module level
    assert depth in (-1, -2)

    wanted_height = abs(depth) - 1
    node_height = largest_dist_from_leaf(node)

    if curr_depth == 0 and node_height <= wanted_height:
        # The tree is not tall enough: report the root itself and stop.
        entries_and_traces.append((node["entry"], trace))
        return

    if node_height == wanted_height:
        entries_and_traces.append((node["entry"], trace))

    # Children see this node as their nearest ancestor.
    child_trace = (node["entry"]["name"], ) + trace
    for child in node["children"]:
        get_entries_at_depth(depth,
                             entries_and_traces,
                             child,
                             curr_depth=curr_depth + 1,
                             trace=child_trace)
def fold_nodes(root: dict, nodes_to_fold: List[str]):
    """Turn every node whose name is in ``nodes_to_fold`` into a leaf.

    The tree is modified in place: a matching node has its ``"children"``
    replaced by an empty list, and nothing below it is visited.

    Returns:
        The same ``root`` object that was passed in.
    """
    pending: List[dict] = [root]
    while pending:
        current = pending.pop()
        if current['entry']['name'] not in nodes_to_fold:
            pending.extend(current["children"])
        else:
            current["children"] = []
    return root
## Operation name cleanup utils ####
def trim_string_back(string: str, width: int) -> str:
    """Shorten ``string`` from the back so it fits into ``width`` characters.

    Strings no longer than ``width`` are returned unchanged. Longer strings
    lose enough trailing characters to leave room for a three-character
    ellipsis, which is appended only when more than three characters survive
    the cut.
    """
    if len(string) <= width:
        return string
    cut = len(string) - width + 3
    trimmed = string[:-cut]
    if len(trimmed) > 3:
        return trimmed + "..."
    return trimmed
def shorten_plot_legend_strings(legend, max_char_len: int):
    """Abbreviate and trim every text item of ``legend`` in place.

    Each label is first passed through ``abbreviate_known_names`` and then
    cut down with ``trim_string_back`` using ``max_char_len`` as the width.
    """
    for text in legend.get_texts():
        short_name = abbreviate_known_names(text.get_text())
        text.set_text(trim_string_back(short_name, max_char_len))
def abbreviate_known_names(name: str) -> str:
    """Replace well-known long substrings of ``name`` by short forms.

    Replacements are applied one after another in the listed order, so
    ``"bfloat16"`` is handled before the shorter ``"float16"``.
    """
    replacements = (
        ("MergedColumnParallelLinear", "MCPLinear"),
        ("QKVParallelLinear", "QKVPLinear"),
        ("RowParallelLinear", "RPLinear"),
        ("weight=", "w="),
        ("bfloat16", "bf16"),
        ("float16", "f16"),
    )
    for long_form, short_form in replacements:
        name = name.replace(long_form, short_form)
    return name
def attempt_to_make_names_unique(entries_and_traces):
    """Disambiguate entries that share a name by appending ancestor names.

    ``entries_and_traces`` is a sequence of ``(entry, trace)`` pairs. For
    every name used by more than one entry, the traces of those entries are
    compared position by position; the entries are renamed in place to
    ``"<name> <- <trace[0]> <- ..."`` up to and including the first position
    where the traces differ. If the compared positions never differ the
    names are left untouched.
    """
    seen_names, repeated_names = set(), set()
    for entry, _ in entries_and_traces:
        entry_name = entry["name"]
        if entry_name in seen_names:
            repeated_names.add(entry_name)
        else:
            seen_names.add(entry_name)

    for repeated in repeated_names:
        clashing = [(entry, trace) for entry, trace in entries_and_traces
                    if entry["name"] == repeated]

        # Walk the traces in lockstep looking for the first level at which
        # the ancestors are not all identical.
        divergence = None
        levels = zip(*(trace for _, trace in clashing))
        for level, ancestors in enumerate(levels):
            if not all(ancestor == ancestors[0] for ancestor in ancestors):
                divergence = level
                break

        if divergence is None:
            # Can't create a unique name; leave the names as they are, they
            # will get aggregated by the pivot_table call.
            continue

        for entry, trace in clashing:
            entry["name"] = " <- ".join((entry["name"], ) +
                                        trace[:divergence + 1])
## Operation grouping utils ####
'''
Group operations in the given dataframe by some high-level ops like,
- gemms
- attention
- rms_norm
etc.
'''
def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
    """Collapse per-operation columns of ``trace_df`` into high-level groups.

    Every column name is tested against an ordered list of classifiers; the
    first classifier that matches claims the column. For each non-empty
    group a new column holding the row-wise sum of its members is appended,
    and the member columns are then dropped. Columns matching no classifier
    are kept unchanged.

    The dataframe is modified in place and also returned.
    """

    def contains_any(op_name: str, needles) -> bool:
        return any(needle in op_name for needle in needles)

    def is_nccl(op_name: str) -> bool:
        return "nccl" in op_name.lower()

    # (group column name, membership test), evaluated in this order. Quant
    # kernels are claimed before the gemm test runs, so gemm never sees them.
    classifiers = [
        ("attention", lambda op: contains_any(
            op, ("flash_fwd", "reshape_and_cache_flash_kernel"))),
        ("quant_ops", lambda op: contains_any(
            op, ("scaled_fp8_quant", "scaled_int8_quant"))),
        ("gemm_ops", lambda op: contains_any(
            op, ("xmma_gemm", "gemv2T_kernel", "splitKreduce",
                 "void cutlass::Kernel", "void cutlass::device_kernel",
                 "s16816gemm"))),
        ("rms_norm_ops", lambda op: "rms_norm_kernel" in op),
        ("vocab_embed_ops", lambda op: "vocabparallelembed" in op.lower()),
        ("mem_ops", lambda op: contains_any(op.lower(),
                                            ("memcpy", "memset"))),
        ("elementwise_ops", lambda op: "elementwise_kernel" in op),
        # nccl ops
        ("nccl_all_reduce_ops", lambda op: is_nccl(op) and contains_any(
            op.lower(), ("all_reduce", "allreduce"))),
        ("nccl_gather_ops",
         lambda op: is_nccl(op) and "gather" in op.lower()),
        ("nccl_broadcast_ops",
         lambda op: is_nccl(op) and "broadcast" in op.lower()),
        ("nccl_other_ops", is_nccl),
        # Reduce ops types
        ("cross_device_reduce_1stage_ops",
         lambda op: "cross_device_reduce_1stage" in op),
        ("cross_device_reduce_2stage_ops",
         lambda op: "cross_device_reduce_2stage" in op),
        ("custom_ar_all_reduce_unreg_ops",
         lambda op: "_C_custom_ar::all_reduce_unreg" in op),
        ("reduce_kernel_ops", lambda op: "reduce_kernel" in op),
    ]

    # Assign each original column to the first group that accepts it.
    unclaimed = list(trace_df)
    groups = []
    for group_name, belongs in classifiers:
        members = [op for op in unclaimed if belongs(op)]
        unclaimed = [op for op in unclaimed if op not in members]
        groups.append((group_name, members))

    # Append one aggregated column per non-empty group, in classifier order.
    for group_name, members in groups:
        if members:
            trace_df[group_name] = trace_df[members].agg("sum", axis=1)

    claimed = [op for _, members in groups for op in members]
    trace_df.drop(claimed, axis=1, inplace=True)
    return trace_df
## Data plotting utils ####
def plot_trace_df(traces_df: pd.DataFrame,
                  plot_metric: str,
                  plot_title: str,
                  output: Optional[Path] = None):
    """Draw a stacked bar chart of ``plot_metric`` per phase and save it.

    The dataframe is pivoted to one row per ``phase`` and one column per
    ``name`` (summing ``plot_metric``), the columns are merged into
    high-level operation groups, and each group becomes one layer of the
    stacked bar for every phase.

    Args:
        traces_df: Rows with at least ``phase``, ``name`` and the
            ``plot_metric`` column.
        plot_metric: Name of the column to plot; also used as y-axis label.
        plot_title: Figure title.
        output: File to write the figure to. When ``None`` nothing is
            written.

    The figure is always closed before returning so repeated calls do not
    accumulate open matplotlib figures.
    """
    phases = traces_df['phase'].unique()
    traces_df = traces_df.pivot_table(index="phase",
                                      columns="name",
                                      values=plot_metric,
                                      aggfunc="sum")

    traces_df = group_trace_by_operations(traces_df)

    # Make the figure
    fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True)
    try:
        # Draw the stacked bars
        bottom = [0] * len(phases)
        for op in list(traces_df):
            values = [traces_df[op][phase] for phase in phases]
            # A phase without this operation shows up as NaN after the pivot.
            values = [0.0 if math.isnan(v) else v for v in values]
            ax.bar(phases, values, label=op, bottom=bottom)
            bottom = [b + v for b, v in zip(bottom, values)]

        # Write the values as text on the bars
        for bar in ax.patches:
            if bar.get_height() != 0:
                ax.text(bar.get_x() + bar.get_width() / 2,
                        bar.get_height() / 2 + bar.get_y(),
                        f"{round(bar.get_height(), 2)}",
                        ha='center',
                        color='w',
                        weight='bold',
                        size=5)

        # Setup legend
        handles, labels = ax.get_legend_handles_labels()
        legend = fig.legend(handles,
                            labels,
                            loc='center left',
                            bbox_to_anchor=(1, 1))
        shorten_plot_legend_strings(legend, 50)

        # Setup labels and title
        plt.setp(ax.get_xticklabels(), rotation=90)
        ax.set_ylabel(plot_metric)
        fig.suptitle(plot_title)

        if output is not None:
            fig.savefig(output, bbox_inches='tight')
            print("Created: ", output)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)
def main(
        json_trace: Path,
        output_directory: Path,
        depth: int,  # Fetch/Plot operations at this depth of the Json tree
        plot_metric: str,
        make_names_unique: bool,
        top_k: int,
        json_nodes_to_fold: List[str],
        step_plot_interval: int = 4):
    """Plot the prefill step and a sample of decode steps of a json trace.

    Args:
        json_trace: Path of the json profile to read. Its first key must be
            ``"context"``; every other key must contain ``"prefill"`` or
            ``"decode"``, with exactly one prefill key.
        output_directory: Directory that receives ``prefill.png`` and
            ``decode_steps.png``.
        depth: -1 for a kernel level plot, -2 for a module level plot.
        plot_metric: Column to plot (e.g. ``cuda_time_ms``).
        make_names_unique: Try to disambiguate equally named entries by
            appending the names of their ancestors.
        top_k: Keep the largest entries by ``cuda_time_us`` and merge the
            rest into ``"others"``; a falsy value disables this.
        json_nodes_to_fold: Names of nodes whose children are not plotted.
        step_plot_interval: Plot every ``step_plot_interval``-th decode
            step; the last decode step is always included.

    Raises:
        ValueError: If ``step_plot_interval`` is smaller than 1.
    """
    if step_plot_interval < 1:
        raise ValueError(
            f"step_plot_interval must be >= 1, got {step_plot_interval}")

    def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame:

        def get_entries_and_traces(key: str):
            entries_and_traces: List[Tuple[Any, Any]] = []
            for root in profile_json[key]["summary_stats"]:
                # Fold nodes in the traces as per user request. i.e. simply
                # make the requested nodes leaf-nodes.
                root = fold_nodes(root, json_nodes_to_fold)
                get_entries_at_depth(depth, entries_and_traces, root)
            return entries_and_traces

        def keep_only_top_entries(df: pd.DataFrame,
                                  metric: str,
                                  top_k: int = 9) -> pd.DataFrame:
            # Everything but the (top_k - 1) largest rows is renamed to
            # "others" so it gets merged by the later pivot.
            df.loc[df.nsmallest(len(df) - top_k + 1, metric).index,
                   ["name"]] = "others"
            return df

        # Get data for each key
        traces = [get_entries_and_traces(key) for key in step_keys]

        # Attempt some cleanup
        if make_names_unique:
            for trace in traces:
                attempt_to_make_names_unique(trace)

        # To pandas dataframe
        trace_dfs = [
            pd.DataFrame([entry for entry, _ in trace]).fillna(0)
            for trace in traces
        ]

        # Respect top_k
        if top_k:
            trace_dfs = [
                keep_only_top_entries(trace_df, "cuda_time_us", top_k)
                for trace_df in trace_dfs
            ]

        # Fill in information about the step-keys
        for trace_df, step_key in zip(trace_dfs, step_keys):
            trace_df['phase'] = step_key

        # Combine all data frames so they can be put in a single plot
        traces_df = pd.concat(trace_dfs)

        # Add a derived metric `cuda_time_ms`
        traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000
        traces_df = traces_df.fillna(0)

        return traces_df

    def make_plot_title_suffix(profile_json: dict) -> str:
        context = profile_json["context"]
        sparsity = context.get('sparsity', None)
        return (f"{context['model']}\n"
                f"Batch={context['batch_size']}, "
                f"PromptLen={context['prompt_len']}, "
                f"OutputLen={context['output_len']},"
                f"NumGpus={context['tensor_parallel_size']}"
                f"{', Sparsity ' + sparsity if sparsity else ''}")

    with open(json_trace, "r") as f:
        profile_json = json.load(f)
    assert profile_json is not None

    # Get all `llm.generate.step()` profile
    step_traces = list(profile_json.keys())
    assert (step_traces[0] == 'context')
    step_traces = step_traces[1:]  # have only prefill and decodes
    prefills = [key for key in step_traces if "prefill" in key]
    all_decodes = [key for key in step_traces if "decode" in key]
    assert len(prefills) + len(all_decodes) == len(step_traces)
    assert len(prefills) == 1

    decodes = all_decodes[::step_plot_interval]
    if all_decodes and decodes[-1] != all_decodes[-1]:
        # Always have the last decode
        decodes.append(all_decodes[-1])

    plot_title_suffix = make_plot_title_suffix(profile_json)

    prefill_traces = prepare_data(profile_json, prefills)
    plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix,
                  output_directory / Path("prefill.png"))

    if decodes:
        decode_traces = prepare_data(profile_json, decodes)
        plot_trace_df(decode_traces, plot_metric,
                      "decodes " + plot_title_suffix,
                      output_directory / Path("decode_steps.png"))
    else:
        # A trace without decode steps has nothing to concatenate or plot.
        print("No decode steps found in the trace; skipping decode plot.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--json-trace",
        type=str,
        required=True,
        help="json trace file output by examples/offline_profile.py")
    parser.add_argument("--output-directory",
                        type=str,
                        required=False,
                        help="Directory to output plots")
    parser.add_argument("--level",
                        type=str,
                        default="module",
                        choices=["module", "kernel"])
    parser.add_argument("--top-k",
                        type=int,
                        default=12,
                        help="Only graph the top `top_k` entries by time.")
    parser.add_argument("--fold-json-node",
                        nargs='+',
                        default=['Sampler', 'LogitsProcessor'],
                        help='Do not plot the children of these nodes. Let, '
                        'the node represent the aggregate of all its '
                        'children')
    parser.add_argument("--plot-metric",
                        type=str,
                        default="cuda_time_ms",
                        help='Metric to plot. some options are cuda_time_ms, '
                        'pct_cuda_time')
    parser.add_argument(
        "--step-plot-interval",
        type=int,
        default=4,
        help="For every `step_plot_interval` steps, plot 1 step")

    args = parser.parse_args()

    # Prepare/Extract relevant args
    make_names_unique = False
    if args.level == "module":
        depth = -2
        make_names_unique = True
    elif args.level == "kernel":
        depth = -1
    else:
        raise Exception(f"Unexpected level value ({args.level})")

    # Default to the directory of the trace; always hand main() a Path.
    output_directory = Path(args.output_directory) if args.output_directory \
        else Path(args.json_trace).parent

    os.makedirs(output_directory, exist_ok=True)

    main(Path(args.json_trace),
         output_directory,
         depth,
         args.plot_metric,
         make_names_unique,
         args.top_k,
         args.fold_json_node,
         step_plot_interval=args.step_plot_interval)
#!/usr/bin/env python3
# Copyright (c) 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
# Modified version of: https://chromium.googlesource.com/chromium/tools/depot_tools.git/+/refs/heads/main/post_build_ninja_summary.py
"""Summarize the last ninja build, invoked with ninja's -C syntax.
> python3 tools/report_build_time_ninja.py -C build/..
Typical output looks like this:
```
Longest build steps for .cpp.o:
1.0 weighted s to build ...torch_bindings.cpp.o (12.4 s elapsed time)
2.0 weighted s to build ..._attn_c.dir/csrc... (23.5 s elapsed time)
2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
Longest build steps for .so (linking):
0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
Longest build steps for .cu.o:
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
15.3 weighted s to build ...machete_mm_... (183.6 s elapsed time)
15.3 weighted s to build ...machete_mm_... (183.7 s elapsed time)
15.5 weighted s to build ...machete_mm_... (185.6 s elapsed time)
15.5 weighted s to build ...machete_mm_... (185.9 s elapsed time)
15.5 weighted s to build ...machete_mm_... (186.2 s elapsed time)
37.4 weighted s to build ...scaled_mm_c3x.cu... (449.0 s elapsed time)
43.9 weighted s to build ...scaled_mm_c2x.cu... (527.4 s elapsed time)
344.8 weighted s to build ...attention_...cu.o (1087.2 s elapsed time)
1110.0 s weighted time (10120.4 s elapsed time sum, 9.1x parallelism)
134 build steps completed, average of 0.12/s
```
"""
import argparse
import errno
import fnmatch
import os
import sys
from collections import defaultdict
# The number of long build times to report: SummarizeEntries prints at most
# this many of the longest (by weighted duration) steps per extension group.
long_count = 10
# The number of long times by extension to report. main() increases this by
# the number of patterns passed via --step-types.
# NOTE(review): no function visible in this section reads this value - confirm
# whether it is still needed.
long_ext_count = 10
class Target:
    """One build step of a .ninja_log file, possibly with several outputs."""

    def __init__(self, start, end):
        """Create a target from its start/end times (float seconds)."""
        self.start = start
        self.end = end
        # Output names of this step; filled in by whoever owns the object.
        self.targets = []
        self.weighted_duration = 0.0

    def Duration(self):
        """Return the elapsed time of the step in seconds as a float."""
        return self.end - self.start

    def SetWeightedDuration(self, weighted_duration):
        """Store the weighted duration, given in seconds as a float."""
        self.weighted_duration = weighted_duration

    def WeightedDuration(self):
        """Return the weighted duration of the step in seconds as a float.

        The weighted duration is the elapsed time of the step divided by the
        number of steps that were running concurrently, i.e. roughly the
        contribution of this step to the total build time. Serialized or
        serializing steps therefore tend to have much larger values.

        It must never exceed the plain duration (up to a small tolerance);
        a violation is printed and then fails an assertion.
        """
        # Allow for modest floating-point errors
        epsilon = 0.000002
        limit = self.Duration() + epsilon
        if self.weighted_duration > limit:
            print('%s > %s?' % (self.weighted_duration, self.Duration()))
        assert (self.weighted_duration <= limit)
        return self.weighted_duration

    def DescribeTargets(self):
        """Return a printable, length-limited summary of the output names."""
        # Some build steps generate dozens of outputs. The limit is chosen so
        # that most long single-target names still fit while keeping word
        # wrapping to a minimum.
        max_length = 65
        joined = ', '.join(self.targets)
        if len(joined) <= max_length:
            return joined
        return joined[:max_length] + '...'
# Copied with some modifications from ninjatracing
def ReadTargets(log, show_all):
    """Read all targets from the open .ninja_log file object |log|.

    Lines with the same command hash are merged into a single Target whose
    ``targets`` list collects all output names. Unless |show_all| is true,
    only the records of the most recent build are kept.

    The result is a list of Target objects in the order in which their
    command hashes were first seen.
    """
    header = log.readline()
    assert header == '# ninja log v5\n', \
        'unrecognized ninja log version %r' % header

    by_hash = {}
    latest_end = 0.0
    for raw_line in log:
        fields = raw_line.strip().split('\t')
        if len(fields) != 5:
            # If ninja.exe is rudely halted then the .ninja_log file may be
            # corrupt. Silently continue.
            continue
        start_ms, end_ms, _, name, cmdhash = fields  # Ignore restat.
        # Convert from integral milliseconds to float seconds.
        start = int(start_ms) / 1000.0
        end = int(end_ms) / 1000.0

        if not show_all and end < latest_end:
            # An earlier time stamp means that this step is the first in a
            # new build, possibly an incremental build. Throw away the
            # previous data so that this new build is displayed on its own.
            # End times are compared because records are written when
            # commands complete, so end times are in order while start times
            # are not.
            by_hash = {}

        target = by_hash.get(cmdhash)
        if target is not None and not show_all and (target.start != start
                                                    or target.end != end):
            # If several builds in a row just run one or two build steps
            # then the end times may not go backwards, so the last build may
            # not be detected as such. A repeated command hash with changed
            # start/stop points is used as a second signal to reset the
            # target dictionary.
            by_hash = {}
            target = None

        if target is None:
            target = Target(start, end)
            by_hash[cmdhash] = target
        latest_end = end
        target.targets.append(name)
    return list(by_hash.values())
def GetExtension(target, extra_patterns):
    """Return the file extension that best represents a target.

    For targets that generate multiple outputs it is important to return a
    consistent 'canonical' extension. Ultimately the goal is to group build
    steps by type.

    |extra_patterns| is an optional semicolon separated list of fnmatch
    patterns; the first pattern contained in an output name is returned
    as-is and takes precedence over the extension based grouping.
    """
    patterns = extra_patterns.split(';') if extra_patterns else []

    for output in target.targets:
        for fn_pattern in patterns:
            if fnmatch.fnmatch(output, '*' + fn_pattern + '*'):
                return fn_pattern

        # Not a true extension, but a good grouping.
        if output.endswith('type_mappings'):
            extension = 'type_mappings'
            break

        # Capture two extensions if present. For example: file.javac.jar
        # should be distinguished from file.interface.jar.
        stem, last_ext = os.path.splitext(output)
        inner_ext = os.path.splitext(stem)[1]
        # Preserve the order in the file name.
        extension = (inner_ext + last_ext) or '(no extension found)'

        if last_ext in ('.pdb', '.dll', '.exe'):
            # Make sure that .dll and .exe are grouped together and that the
            # .dll.lib files don't cause these to be listed as libraries
            extension = 'PEFile (linking)'
            break
        if last_ext in ('.so', '.TOC'):
            # Attempt to identify linking, avoid identifying as '.TOC'
            extension = '.so (linking)'
            break
        # .obj/.o must not get categorized as mojo files, and jars are the
        # canonical output of java targets: keep the extension found so far.
        if last_ext in ('.obj', '.o', '.jar'):
            break
        # Normalize all mojo related outputs to 'mojo'.
        if '.mojom' in output:
            extension = 'mojo'
            break
    return extension
def SummarizeEntries(entries, extra_step_types):
    """Print a summary of the passed in list of Target objects.

    Computes the weighted duration of every target (its elapsed time divided
    by the number of concurrently running targets), stores it on the target,
    and prints the longest steps per extension group followed by overall
    totals.

    |extra_step_types| is forwarded to GetExtension as additional grouping
    patterns. An empty |entries| list only prints a short notice.
    """
    if not entries:
        # Nothing to sweep over; avoids indexing an empty event list and
        # dividing by a zero build length below.
        print('No build steps found, no build summary created.')
        return

    # Create a list that is in order by time stamp and has entries for the
    # beginning and ending of each build step (one time stamp may have
    # multiple entries due to multiple steps starting/stopping at exactly the
    # same time). Iterate through this list, keeping track of which tasks are
    # running at all times. At each time step calculate a running total for
    # weighted time so that when each task ends its own weighted time can
    # easily be calculated.
    earliest = min(target.start for target in entries)
    latest = max(target.end for target in entries)
    total_cpu_time = sum(target.Duration() for target in entries)
    length = latest - earliest

    task_start_stop_times = []
    for target in entries:
        task_start_stop_times.append((target.start, 'start', target))
        task_start_stop_times.append((target.end, 'stop', target))

    # Sort by the time/type records and ignore |target|. If a task starts and
    # stops on the same time stamp then the start will come first because of
    # the alphabet, which is important for making this work correctly.
    task_start_stop_times.sort(key=lambda times: times[:2])

    # Track the tasks which are currently running.
    running_tasks = {}
    # Record the time we have processed up to so we know how to calculate
    # time deltas.
    last_time = task_start_stop_times[0][0]
    # Track the accumulated weighted time so that it can efficiently be added
    # to individual tasks.
    last_weighted_time = 0.0
    weighted_total = 0.0

    # Scan all start/stop events.
    for time, action_name, target in task_start_stop_times:
        # Accumulate weighted time up to now.
        num_running = len(running_tasks)
        if num_running > 0:
            # Update the total weighted time up to this moment.
            last_weighted_time += (time - last_time) / float(num_running)
        if action_name == 'start':
            # Record the total weighted task time when this task starts.
            running_tasks[target] = last_weighted_time
        if action_name == 'stop':
            # Record the change in the total weighted task time while this
            # task ran.
            weighted_duration = last_weighted_time - running_tasks[target]
            target.SetWeightedDuration(weighted_duration)
            weighted_total += weighted_duration
            del running_tasks[target]
        last_time = time
    assert (len(running_tasks) == 0)

    # Warn if the sum of weighted times is off by more than half a second.
    # All times here are float seconds, so the threshold is 0.5 (not 500).
    if abs(length - weighted_total) > 0.5:
        print('Warning: Possible corrupt ninja log, results may be '
              'untrustworthy. Length = %.3f, weighted total = %.3f' %
              (length, weighted_total))

    entries_by_ext = defaultdict(list)
    for target in entries:
        extension = GetExtension(target, extra_step_types)
        entries_by_ext[extension].append(target)

    for key, values in entries_by_ext.items():
        print('    Longest build steps for %s:' % key)
        values.sort(key=lambda x: x.WeightedDuration())
        for target in values[-long_count:]:
            print(
                '      %8.1f weighted s to build %s (%.1f s elapsed time)' %
                (target.WeightedDuration(), target.DescribeTargets(),
                 target.Duration()))

    # A build whose steps all have zero duration has no meaningful rates.
    parallelism = total_cpu_time * 1.0 / length if length > 0 else 0.0
    steps_per_second = len(entries) / length if length > 0 else 0.0
    print('    %.1f s weighted time (%.1f s elapsed time sum, %1.1fx '
          'parallelism)' % (length, total_cpu_time, parallelism))
    print('    %d build steps completed, average of %1.2f/s' %
          (len(entries), steps_per_second))
def main():
    """Parse the command line and print a summary of a ninja log.

    The log defaults to ``.ninja_log`` in the current directory, or in the
    directory given with ``-C``; ``--log-file`` overrides both.

    Returns:
        ``errno.ENOENT`` when the log file cannot be opened, otherwise
        ``None``.
    """
    log_file = '.ninja_log'
    parser = argparse.ArgumentParser()
    parser.add_argument('-C', dest='build_directory', help='Build directory.')
    parser.add_argument(
        '-s',
        '--step-types',
        help='semicolon separated fnmatch patterns for build-step grouping')
    parser.add_argument('--log-file',
                        help="specific ninja log file to analyze.")
    args, _extra_args = parser.parse_known_args()
    if args.build_directory:
        log_file = os.path.join(args.build_directory, log_file)
    if args.log_file:
        log_file = args.log_file
    if args.step_types:
        # Make room for the extra build types.
        global long_ext_count
        long_ext_count += len(args.step_types.split(';'))

    # Only the open() is guarded: errors raised later while parsing or
    # printing must not be misreported as a missing log file.
    try:
        log = open(log_file, 'r')
    except IOError:
        print('Log file %r not found, no build summary created.' % log_file)
        return errno.ENOENT

    with log:
        entries = ReadTargets(log, False)
    SummarizeEntries(entries, args.step_types)


if __name__ == '__main__':
    sys.exit(main())
......@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
......@@ -19,7 +19,7 @@ __all__ = [
"__version_tuple__",
"LLM",
"ModelRegistry",
"PromptInputs",
"PromptType",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
......
import contextlib
import functools
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.library
import vllm.envs as envs
from vllm._core_ext import ScalarType
......@@ -30,6 +31,16 @@ with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True
if TYPE_CHECKING:
def register_fake(fn):
return lambda name: fn
else:
try:
from torch.library import register_fake
except ImportError:
from torch.library import impl_abstract as register_fake
def hint_on_error(fn):
......@@ -37,6 +48,15 @@ def hint_on_error(fn):
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except NotImplementedError as e:
msg = (
"Error in calling custom op %s: %s\n"
"Not implemented or built, mostly likely because the current current device "
"does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
"incorrectly while building)")
logger.error(msg, fn.__name__, e)
raise NotImplementedError(msg % (fn.__name__, e)) from e
except AttributeError as e:
msg = (
"Error in calling custom op %s: %s\n"
......@@ -211,7 +231,7 @@ def paged_attention_v2_opt(
blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt)
# page attention ops (opt_tc)
def paged_attention_v1_opt_tc(
out: torch.Tensor,
query: torch.Tensor,
......@@ -453,7 +473,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
if hasattr(torch.ops._C, "gptq_gemm"):
@torch.library.register_fake("_C::gptq_gemm")
@register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
......@@ -489,7 +509,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@torch.library.register_fake("_C::gptq_marlin_24_gemm")
@register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
......@@ -497,7 +517,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_n: int, size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::gptq_marlin_gemm")
@register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
......@@ -514,12 +534,12 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::ggml_dequantize")
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
@register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
......@@ -528,7 +548,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_a8")
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
......@@ -538,7 +558,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::marlin_qqq_gemm")
@register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
......@@ -548,7 +568,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=torch.float16,
device=a.device)
@torch.library.register_fake("_C::marlin_gemm")
@register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
......@@ -557,7 +577,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=torch.float16,
device=a.device)
@torch.library.register_fake("_C::awq_dequantize")
@register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
......@@ -568,7 +588,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=scales.dtype,
device=scales.device)
@torch.library.register_fake("_C::awq_gemm")
@register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor:
......@@ -577,7 +597,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=input.dtype,
device=input.device).sum(0)
@torch.library.register_fake("_C::aqlm_gemm")
@register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
......@@ -593,7 +613,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))
@torch.library.register_fake("_C::aqlm_dequant")
@register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
......@@ -603,14 +623,14 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=codebooks.dtype,
device=codebooks.device)
@torch.library.register_fake("_C::fp8_marlin_gemm")
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@torch.library.register_fake("_C::machete_gemm")
@register_fake("_C::machete_gemm")
def machete_gemm_fake(
a: torch.Tensor,
# Should be the tensor returned by machete_prepack_B
......@@ -628,41 +648,45 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::machete_prepack_B")
@register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)
@torch.library.register_fake("_C::causal_conv1d_fwd")
@register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
seq_idx_: Optional[torch.Tensor],
initial_states_: Optional[torch.Tensor],
final_states_out_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(
u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
a = torch.empty_like(u)
if z_ is not None:
c = torch.empty_like(z_)
return [a, c]
else:
return [a]
conv_states: Optional[torch.Tensor],
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool, pad_slot_id: int):
return None
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
) -> None:
    """Fake implementation registered for the ``_C::causal_conv1d_update`` op.

    None of the arguments are read. The function allocates nothing and
    returns None, matching its ``-> None`` annotation.
    """
    return
@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D_: Optional[torch.Tensor],
    z_: Optional[torch.Tensor],
    delta_bias_: Optional[torch.Tensor],
    delta_softplus: bool,
    cu_seq_len: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    ssm_states: Optional[torch.Tensor],
    pad_slot_id: int,
) -> None:
    """Fake implementation registered for the ``_C::selective_scan_fwd`` op.

    None of the arguments are read. The function allocates nothing and
    returns None, matching its ``-> None`` annotation.
    """
    return
# cutlass
......@@ -756,6 +780,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
return output
def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                          size_k: int, size_n: int,
                          num_bits: int) -> torch.Tensor:
    """Repack every expert's weight slice with ``_C.awq_marlin_repack``.

    Args:
        b_q_weight: Stacked per-expert weights; dimension 0 indexes experts.
        perm: Part of the signature but not read by this function.
        size_k: Forwarded to the op; must be a multiple of 16.
        size_n: Forwarded to the op; also sizes the last output dimension.
        num_bits: Forwarded to the op; also sizes the last output dimension.

    Returns:
        A tensor of shape
        ``(num_experts, size_k // 16, size_n * (num_bits // 2))`` on the
        device and with the dtype of ``b_q_weight``, where row ``e`` holds
        the op's result for ``b_q_weight[e]``.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0

    packed_shape = (num_experts, size_k // 16, size_n * (num_bits // 2))
    output = torch.empty(packed_shape,
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)

    # The op is resolved per expert, so a zero-expert input never touches it.
    for expert_idx in range(num_experts):
        expert_weight = b_q_weight[expert_idx]
        output[expert_idx] = torch.ops._C.awq_marlin_repack(
            expert_weight, size_k, size_n, num_bits)
    return output
def gptq_marlin_gemm(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
......@@ -813,7 +851,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
if hasattr(torch.ops._C, "permute_cols"):
@torch.library.register_fake("_C::permute_cols")
@register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
......@@ -959,37 +997,41 @@ def ggml_mul_mat_a8(
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
seq_idx_: Optional[torch.Tensor],
initial_states_: Optional[torch.Tensor],
final_states_out_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_,
initial_states_, final_states_out_,
silu_activation)
def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
conv_state_indices: Optional[torch.Tensor],
) -> torch.Tensor:
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation,
conv_state_indices)
conv_states: Optional[torch.Tensor],
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool, pad_slot_id: int):
torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
query_start_loc, cache_indices,
has_initial_state, silu_activation,
pad_slot_id)
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
                         silu_activation: bool,
                         cache_seqlens: Optional[torch.Tensor],
                         conv_state_indices: Optional[torch.Tensor],
                         pad_slot_id: int):
    """Invoke ``torch.ops._C.causal_conv1d_update`` with the given arguments.

    The arguments are passed positionally, in signature order. Whatever the
    op returns is discarded, so this wrapper always returns None.
    """
    op_args = (
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
    torch.ops._C.causal_conv1d_update(*op_args)
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_,
delta_bias_, delta_softplus, index_,
x)
delta_softplus: bool,
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: torch.Tensor, pad_slot_id: int):
torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
delta_softplus, query_start_loc,
cache_indices, has_initial_state,
ssm_states, pad_slot_id)
# moe
......@@ -1011,16 +1053,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@torch.library.register_fake("_moe_C::marlin_gemm_moe")
@register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor, b_scales: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor,
workspace: torch.Tensor, b_q_type: ScalarType,
size_m: int, size_n: int, size_k: int,
is_k_full: bool, num_experts: int, topk: int,
moe_block_size: int, replicate_input: bool,
b_zero_points: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int, size_n: int,
size_k: int, is_k_full: bool, num_experts: int,
topk: int, moe_block_size: int,
replicate_input: bool,
apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n),
dtype=a.dtype,
......
......@@ -186,6 +186,9 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[
......@@ -357,6 +360,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
......@@ -373,7 +378,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
......@@ -399,7 +404,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
assert kv_cache is None \
assert kv_cache.numel() == 0 \
or prefill_meta.block_tables is None \
or prefill_meta.block_tables.numel() == 0, \
"Does not support prefix-enabled attention."
......
......@@ -13,152 +13,15 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.forward_context import get_forward_context
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
# yapf: disable
from vllm.vllm_flash_attn import (
flash_attn_varlen_func as _flash_attn_varlen_func)
from vllm.vllm_flash_attn import (
flash_attn_with_kvcache as _flash_attn_with_kvcache)
# yapf: enable
@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Optional[List[int]] = None,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Custom-op wrapper that delegates to ``_flash_attn_varlen_func``.

    Registered as ``vllm::flash_attn_varlen_func`` with no mutated arguments.
    A custom op cannot take a tuple input, so ``window_size`` arrives as an
    optional two-element list and is converted to a tuple here; ``None``
    becomes ``(-1, -1)``. All other arguments are forwarded by keyword.
    """
    real_window_size: Tuple[int, int]
    if window_size is not None:
        assert len(window_size) == 2
        real_window_size = (window_size[0], window_size[1])
    else:
        real_window_size = (-1, -1)

    attn_out = _flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=real_window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        block_table=block_table,
    )
    return attn_out
@flash_attn_varlen_func.register_fake  # type: ignore
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Optional[List[int]] = None,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake implementation for ``flash_attn_varlen_func``.

    Only ``q`` is read: the result is an uninitialized tensor created with
    ``torch.empty_like(q)``.
    """
    placeholder = torch.empty_like(q)
    return placeholder
@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
def flash_attn_with_kvcache(
    decode_query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cache_seqlens: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    alibi_slopes: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
) -> torch.Tensor:
    """Custom-op wrapper that delegates to ``_flash_attn_with_kvcache``.

    Registered as ``vllm::flash_attn_with_kvcache`` with no mutated
    arguments. The query and the two cache tensors are passed positionally;
    the remaining options are passed by keyword, unchanged.
    """
    options = {
        "cache_seqlens": cache_seqlens,
        "block_table": block_table,
        "softmax_scale": softmax_scale,
        "causal": causal,
        "alibi_slopes": alibi_slopes,
        "softcap": softcap,
    }
    return _flash_attn_with_kvcache(decode_query, key_cache, value_cache,
                                    **options)
@flash_attn_with_kvcache.register_fake  # type: ignore
def _(
    decode_query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cache_seqlens: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    alibi_slopes: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
) -> torch.Tensor:
    """Fake implementation for ``flash_attn_with_kvcache``.

    Only ``decode_query`` is read: the result is an uninitialized tensor
    created with ``torch.empty_like(decode_query)``.
    """
    placeholder = torch.empty_like(decode_query)
    return placeholder
@torch.library.custom_op("vllm::reshape_and_cache_flash",
                         mutates_args=["kv_cache"])
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Call ``_C_cache_ops.reshape_and_cache_flash`` on a combined KV cache.

    Inductor cannot deal with in-place operations on views; see
    https://github.com/pytorch/pytorch/issues/131192
    and https://github.com/pytorch/pytorch/issues/130174
    Taking the whole ``kv_cache`` and indexing it inside this custom op is a
    workaround that hides the view operation from Inductor.
    """
    # kv_cache[0] is passed as the key cache, kv_cache[1] as the value cache.
    key_cache = kv_cache[0]
    value_cache = kv_cache[1]
    return torch.ops._C_cache_ops.reshape_and_cache_flash(
        key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
        k_scale, v_scale)
@reshape_and_cache_flash.register_fake  # type: ignore
def _(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Fake implementation for ``reshape_and_cache_flash``.

    Does nothing and returns None; no argument is read.
    """
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
class FlashAttentionBackend(AttentionBackend):
......@@ -245,8 +108,12 @@ class FlashAttentionMetadata(AttentionMetadata):
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
# Maximum query length in the batch.
max_query_len: Optional[int]
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
......@@ -305,6 +172,7 @@ class FlashAttentionMetadata(AttentionMetadata):
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
......@@ -331,20 +199,27 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
query_start_loc=self.query_start_loc[self.num_prefills:]
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
......@@ -355,6 +230,23 @@ class FlashAttentionMetadata(AttentionMetadata):
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
......@@ -366,7 +258,6 @@ class FlashAttentionMetadata(AttentionMetadata):
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
......@@ -414,8 +305,6 @@ class FlashAttentionMetadataBuilder(
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
......@@ -441,9 +330,6 @@ class FlashAttentionMetadataBuilder(
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
......@@ -467,13 +353,37 @@ class FlashAttentionMetadataBuilder(
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def _get_graph_runner_block_tables(
self, num_seqs: int,
block_tables: List[List[int]]) -> torch.Tensor:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
assert max_batch_size >= num_seqs
graph_block_tables = self.runner.graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
graph_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables[
i, :max_blocks] = block_table[:max_blocks]
return torch.from_numpy(graph_block_tables).to(
device=self.runner.device, non_blocking=True)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
......@@ -498,33 +408,22 @@ class FlashAttentionMetadataBuilder(
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]
block_tables = torch.from_numpy(input_block_tables).to(
device=device, non_blocking=True)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
......@@ -566,6 +465,7 @@ class FlashAttentionMetadataBuilder(
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
......@@ -665,6 +565,8 @@ class FlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
......@@ -679,106 +581,198 @@ class FlashAttentionImpl(AttentionImpl):
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch.ops.vllm.reshape_and_cache_flash(
key,
value,
kv_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
k_scale,
v_scale,
output = torch.ops.vllm.unified_flash_attention(
query,
key,
value,
self.num_heads,
self.head_size,
self.num_kv_heads,
kv_cache,
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
return output
@torch.library.custom_op("vllm::unified_flash_attention",
mutates_args=["kv_cache"])
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache is None or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
prefill_output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
decode_output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
if decode_meta.max_decode_query_len > 1:
decode_output = flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)
if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size)
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
@unified_flash_attention.register_fake
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation for ``unified_flash_attention``.

    Only ``query`` is read: the result is an uninitialized tensor created
    with ``torch.empty_like(query)``.
    """
    placeholder = torch.empty_like(query)
    return placeholder
......@@ -7,7 +7,7 @@ try:
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
import vllm.attention.backends.flash_attn # noqa
from vllm.vllm_flash_attn import flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
......@@ -26,6 +26,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
......@@ -410,18 +411,22 @@ class FlashInferMetadata(AttentionMetadata):
return self
def advance_step(
self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
):
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with flashinfer yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")
assert num_seqs > 0
assert num_queries > 0
assert model_input.attn_metadata is not None
......@@ -470,8 +475,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
......@@ -537,9 +540,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
is_profile_run = is_block_tables_empty(block_tables)
# Compute slot mapping.
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
......@@ -591,7 +594,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
......@@ -630,7 +632,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
......@@ -746,7 +747,7 @@ class FlashInferImpl(AttentionImpl):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
......@@ -759,73 +760,132 @@ class FlashInferImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if attn_metadata.num_prefill_tokens > 0:
assert attn_metadata.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache is not None:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
k_scale,
v_scale,
return torch.ops.vllm.unified_flash_infer(
query,
key,
value,
self.num_heads,
self.head_size,
self.num_kv_heads,
kv_cache,
self.kv_cache_dtype,
k_scale,
v_scale,
self.scale,
self.sliding_window,
self.alibi_slopes,
self.logits_soft_cap,
)
@torch.library.custom_op("vllm::unified_flash_infer",
mutates_args=["kv_cache"])
def unified_flash_infer(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashInferMetadata)
attn_metadata: FlashInferMetadata = current_metadata
num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if attn_metadata.num_prefill_tokens > 0:
assert attn_metadata.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
query = query.contiguous() # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
query = query.contiguous(
) # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=self.logits_soft_cap,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=self.scale,
logits_soft_cap=self.logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale)
return output.view(num_tokens, hidden_size)
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(
query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale)
return output.view(num_tokens, hidden_size)
@unified_flash_infer.register_fake
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation registered on ``unified_flash_infer``.

    No attention is computed. The result is an uninitialized, contiguous
    tensor derived from ``torch.empty_like(query)``; every argument other
    than ``query`` is accepted but not read.
    """
    placeholder_output = torch.empty_like(query)
    return placeholder_output.contiguous()
......@@ -167,7 +167,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
......@@ -180,6 +180,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
......@@ -196,7 +198,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = self.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
ipex_ops.reshape_and_cache(
......@@ -212,7 +214,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if (kv_cache.numel() == 0
or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
......
......@@ -9,6 +9,31 @@ from vllm.attention.backends.abstract import (AttentionBackend,
from vllm.attention.backends.utils import CommonAttentionState
def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
                     src_offset: int, dst_offset: int) -> None:
    """Copy a single block from ``src_tensor`` into ``dst_tensor``.

    A region-of-interest (ROI) tensor selecting index ``src_offset`` on
    dimension 0 of ``src_tensor`` is copied into the ROI selecting index
    ``dst_offset`` on dimension 0 of ``dst_tensor``.
    """

    def _block_view(tensor: ov.Tensor, block_number: int) -> ov.Tensor:
        # The begin coordinate is hard-coded with four entries; the end
        # coordinate starts out as the tensor's full shape. Only dimension 0
        # is narrowed, to the half-open range
        # [block_number, block_number + 1).
        begin = ov.runtime.Coordinate([0, 0, 0, 0])
        end = ov.runtime.Coordinate(tensor.get_shape())
        begin[0] = block_number
        end[0] = block_number + 1

        # Build the ROI with ov.Tensor when the input is an ov.Tensor,
        # otherwise with ov.RemoteTensor.
        roi_cls = (ov.Tensor
                   if isinstance(tensor, ov.Tensor) else ov.RemoteTensor)
        return roi_cls(tensor, begin, end)

    source_view = _block_view(src_tensor, src_offset)
    target_view = _block_view(dst_tensor, dst_offset)
    source_view.copy_to(target_view)
class OpenVINOAttentionBackend(AttentionBackend):
@staticmethod
......@@ -44,13 +69,12 @@ class OpenVINOAttentionBackend(AttentionBackend):
@staticmethod
def swap_blocks(
src_kv_cache: ov.Tensor,
dst_kv_cache: ov.Tensor,
src_to_dst: torch.Tensor,
src_tensor: ov.Tensor,
dst_tensor: ov.Tensor,
src_to_dists: List[Tuple[int, int]],
) -> None:
# OpenVINO currently supports only CPU, which does not require
# swap of KV cache blocks
raise NotImplementedError
for src, dst in src_to_dists:
copy_cache_block(src_tensor, dst_tensor, src, dst)
@staticmethod
def copy_blocks(
......@@ -59,8 +83,8 @@ class OpenVINOAttentionBackend(AttentionBackend):
) -> None:
for src, dst in src_to_dists:
for key_cache, value_cache in kv_caches:
key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :]
copy_cache_block(key_cache, key_cache, src, dst)
copy_cache_block(value_cache, value_cache, src, dst)
@dataclass
......
......@@ -130,7 +130,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
assert tpu_type is not None
tpu_type = tpu_type.lower()
if "lite" not in tpu_type:
if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
else:
......@@ -143,7 +143,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
......@@ -155,8 +155,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
with shape [0] for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
......@@ -173,7 +175,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
if kv_cache[0] is not None:
if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
......@@ -205,36 +207,55 @@ class PallasAttentionBackendImpl(AttentionImpl):
output = output.permute(0, 2, 1, 3)
else:
# Decoding run.
assert kv_cache is not None
assert kv_cache[0].numel() > 0
query = query.squeeze(dim=1)
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
if self.megacore_mode == "batch" and batch_size % 2 != 0:
megacore_mode = None
else:
megacore_mode = self.megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if megacore_mode is not None:
output = torch.ops.xla.paged_attention(
query.squeeze(dim=1),
assert attn_metadata.block_tables is not None
assert attn_metadata.context_lens is not None
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
# block table in SMEM. Therefore, if the block table is too large,
# the kernel compilation will fail. To avoid this, we split the
# batch dimension into smaller chunks and run the kernel multiple
# times.
MAX_SMEM_USAGE = 512 * 1024
size_per_seq = 4 * attn_metadata.block_tables.shape[1]
max_num_seq = MAX_SMEM_USAGE // size_per_seq
if batch_size <= max_num_seq:
output = paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
megacore_mode=megacore_mode,
self.megacore_mode,
)
else:
output = torch.ops.xla.paged_attention(
query.squeeze(dim=1),
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
)
chunk_size = max_num_seq
# Make sure the chunk size is a multiple of 2.
chunk_size = chunk_size // 2 * 2
num_chunks = (batch_size + chunk_size - 1) // chunk_size
output = torch.empty_like(query)
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * chunk_size
chunk_end = chunk_start + chunk_size
# NOTE(woosuk): We skip this line because it causes Dynamo
# compilation error. Instead, we rely on the slice operation
# to handle the out-of-bound case.
# chunk_end = min(chunk_end, batch_size)
chunk_output = paged_attention(
query[chunk_start:chunk_end],
key_cache,
value_cache,
attn_metadata.context_lens[chunk_start:chunk_end],
attn_metadata.block_tables[chunk_start:chunk_end],
pages_per_compute_block,
self.megacore_mode,
)
output[chunk_start:chunk_end] = chunk_output
# Reshape the output tensor.
return output.reshape(batch_size, seq_len, hidden_size)
......@@ -256,3 +277,43 @@ def write_to_kv_cache(
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
) -> torch.Tensor:
    """Call ``torch.ops.xla.paged_attention`` and return its result.

    ``megacore_mode`` is forwarded as a keyword argument only when it is not
    ``None``. A ``"batch"`` mode is discarded (treated as ``None``) when
    ``query.shape[0]`` is odd.
    """
    if megacore_mode == "batch" and query.shape[0] % 2 != 0:
        megacore_mode = None

    # NOTE(woosuk): A temporary workaround to avoid the error:
    # "xla::paged_attention() Expected a value of type 'str' for
    # argument 'megacore_mode' but instead found type 'NoneType'."
    if megacore_mode is None:
        return torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
        )
    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=megacore_mode,
    )
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
# Placeholder attention backend for models like Mamba and embedding models that
# lack attention.
class PlaceholderAttentionBackend(AttentionBackend):
    """Backend stub used when no attention is needed.

    Each accessor names one of the placeholder classes of this module (or
    ``CommonAttentionState``); the cache-management hooks do nothing.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "placeholder-attn"

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        """Return the attention implementation class."""
        return PlaceholderAttentionImpl

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return PlaceholderAttentionMetadataBuilder

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        """Return the metadata class."""
        return PlaceholderAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return a fixed five-element shape of ones; arguments are unused."""
        return (1, ) * 5

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing."""
        return None
@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode requests batched together."""

    # (batch_size,). Per-sequence lengths, where a sequence length is the
    # number of computed tokens plus new tokens. None for decoding.
    seq_lens: Optional[List[int]]
    # The same sequence lengths, held as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Maximum number of query tokens among the requests in the batch.
    max_decode_query_len: Optional[int]

    # Maximum sequence length in the prefill batch. 0 when the batch holds
    # decoding requests only.
    max_prefill_seq_len: int
    # Maximum sequence length in the decode batch. 0 when the batch holds
    # prefill requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). Cumulative subquery lengths of the sequences in the
    # batch, used to index into subquery. E.g., subquery lengths [4, 6]
    # give [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into sequence. E.g., sequence lengths [4, 6]
    # give [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). Context lengths (tokens that are computed so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd
    # blocks of the kv cache. Each block can hold up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the prefill-only slice of this metadata.

        Returns None when ``num_prefills`` is 0. Otherwise the slice is
        built on first access and the cached object is returned afterwards.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.seq_start_loc is not None

        # Placeholders
        empty_slot_mapping = torch.empty(0)
        empty_block_tables = torch.empty(0)

        prefill_view = PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=empty_slot_mapping,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=empty_block_tables,
            use_cuda_graph=False,
        )
        self._cached_prefill_metadata = prefill_view
        return prefill_view

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the decode-only slice of this metadata.

        Returns None when ``num_decode_tokens`` is 0. Otherwise the slice is
        built on first access and the cached object is returned afterwards.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.seq_lens_tensor is not None

        # Placeholders
        empty_slot_mapping = torch.empty(0)
        empty_block_tables = torch.empty(0)

        decode_view = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=empty_slot_mapping,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=empty_block_tables,
            use_cuda_graph=self.use_cuda_graph,
        )
        self._cached_decode_metadata = decode_view
        return decode_view
class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
    """Builds ``PlaceholderAttentionMetadata`` from an input builder.

    Per-sequence bookkeeping is accumulated by ``_add_seq_group`` and turned
    into on-device tensors by ``build``.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Fold one sequence group into the builder's running state.

        For every sequence the context length is recorded. Prompt sequences
        additionally bump the prefill counters and record their sequence
        length; other sequences must have a query length of 1 and bump the
        decode token counter. ``chunked_prefill_enabled`` is not read.
        """
        is_prompt = inter_data.is_prompt
        per_sequence = zip(
            inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
            inter_data.orig_seq_lens, inter_data.seq_lens,
            inter_data.query_lens, inter_data.context_lens,
            inter_data.curr_sliding_window_blocks)

        for (_seq_id, token_len, seq_len, curr_seq_len, query_len,
             context_len, _sliding_window_block) in per_sequence:
            self.context_lens.append(context_len)

            if not is_prompt:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)
                continue

            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Raises:
            ValueError: If the model's hf_config defines a non-None
                ``attn_logit_softcapping``.
        """
        chunked_prefill = self.input_builder.chunked_prefill_enabled
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data, chunked_prefill)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        soft_cap = getattr(self.runner.model_config.hf_config,
                           "attn_logit_softcapping", None)
        if soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        # Decode entries follow the prefill entries in query_lens; fall back
        # to 1 when there are none.
        max_decode_query_len = max(query_lens[self.num_prefills:], default=1)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)

        # With a captured graph the decode token count is the batch size.
        num_decode_tokens = (batch_size
                             if use_captured_graph else self.num_decode_tokens)

        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)

        # Start-location tensors: a leading zero followed by the cumulative
        # sums written in place into the tail.
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        # Placeholders
        empty_slot_mapping = torch.empty(0)
        empty_block_tables = torch.empty(0)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=empty_slot_mapping,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=empty_block_tables,
            use_cuda_graph=use_captured_graph,
        )
class PlaceholderAttentionImpl(AttentionImpl):
    """Attention implementation stub whose ``forward`` is not implemented."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept any arguments and store nothing."""

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError
......@@ -116,9 +116,14 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
......@@ -183,12 +188,22 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
)
return self._cached_decode_metadata
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with rocm_flash_attn yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
......@@ -398,10 +413,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
......@@ -414,7 +433,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
......@@ -451,7 +470,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
......
......@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
......@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None
self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
self.cross_attn_bias: Optional[List[torch.Tensor]] = None
@property
def is_all_encoder_attn_metadata_set(self):
    '''
    True when every attribute required for encoder attention
    (encoder_seq_lens, encoder_seq_lens_tensor, max_encoder_seq_len)
    is not None.
    '''
    required = (
        self.encoder_seq_lens,
        self.encoder_seq_lens_tensor,
        self.max_encoder_seq_len,
    )
    return all(value is not None for value in required)
@property
def is_all_cross_attn_metadata_set(self):
    '''
    True when the encoder attention metadata is fully set and, in
    addition, cross_slot_mapping and cross_block_tables are not None.
    '''
    if not self.is_all_encoder_attn_metadata_set:
        return False
    return (self.cross_slot_mapping is not None
            and self.cross_block_tables is not None)
@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
......@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
return self
def get_seq_lens(
    self,
    attn_type: AttentionType,
):
    '''
    Select the query and key/value sequence lengths stored on this
    metadata object for the given attention type.

    Arguments:

    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Sequence lengths for the query
      (seq_lens, or encoder_seq_lens for encoder attention)
    * Sequence lengths for key & value
      (seq_lens for decoder self-attention, otherwise encoder_seq_lens)

    Raises AttributeError for any other attention type.
    '''
    if attn_type == AttentionType.DECODER:
        return self.seq_lens, self.seq_lens
    if attn_type == AttentionType.ENCODER:
        return self.encoder_seq_lens, self.encoder_seq_lens
    if attn_type == AttentionType.ENCODER_DECODER:
        return self.seq_lens, self.encoder_seq_lens
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_attn_bias(
    self,
    attn_type: AttentionType,
) -> Optional[List[torch.Tensor]]:
    '''
    Return the attention bias stored for the requested attention type.

    Arguments:

    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * `attn_bias` for decoder self-attention, `encoder_attn_bias` for
      encoder attention, `cross_attn_bias` for cross-attention

    Raises:

    * AttributeError if attn_type is not one of the three known types
    '''
    if attn_type == AttentionType.DECODER:
        return self.attn_bias
    if attn_type == AttentionType.ENCODER:
        return self.encoder_attn_bias
    if attn_type == AttentionType.ENCODER_DECODER:
        return self.cross_attn_bias
    raise AttributeError(f"Invalid attention type {attn_type!s}")
def set_attn_bias(
    self,
    attn_bias: List[torch.Tensor],
    attn_type: AttentionType,
) -> None:
    '''
    Store an attention bias in the field that belongs to the requested
    attention type.

    Arguments:

    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:

    * AttributeError if attn_type is not one of the three known types
    '''
    if attn_type == AttentionType.DECODER:
        self.attn_bias = attn_bias
        return
    if attn_type == AttentionType.ENCODER:
        self.encoder_attn_bias = attn_bias
        return
    if attn_type == AttentionType.ENCODER_DECODER:
        self.cross_attn_bias = attn_bias
        return
    raise AttributeError(f"Invalid attention type {attn_type!s}")
def get_seq_len_block_table_args(
    self,
    attn_type: AttentionType,
) -> tuple:
    '''
    Select the sequence-length and block-table fields that apply to the
    requested attention type.

    Decoder attn -> decoder self-attention fields only
    Encoder/decoder cross-attn -> encoder sequence lengths &
                                  cross-attn block tables
    Encoder attn -> encoder sequence lengths & no block tables

    Arguments:

    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AttributeError if attn_type is not one of the three known types
    '''
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: decode-side lengths and block tables.
        seq_lens_tensor = self.seq_lens_tensor
        max_seq_len = self.max_decode_seq_len
        block_tables = self.block_tables
        return (seq_lens_tensor, max_seq_len, block_tables)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        seq_lens_tensor = self.encoder_seq_lens_tensor
        max_seq_len = self.max_encoder_seq_len
        block_tables = self.cross_block_tables
        return (seq_lens_tensor, max_seq_len, block_tables)

    if attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        seq_lens_tensor = self.encoder_seq_lens_tensor
        max_seq_len = self.max_encoder_seq_len
        return (seq_lens_tensor, max_seq_len, None)

    raise AttributeError(f"Invalid attention type {attn_type!s}")
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
......@@ -151,7 +319,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
......@@ -164,88 +332,108 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TorchSDPABackendImpl")
num_tokens, hidden_size = query.shape
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype, k_scale,
v_scale)
if attn_metadata.is_prompt:
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
k_scale, v_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)
if attn_metadata.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
start = 0
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[None, :, start:end, :],
key[None, :, start:end, :],
value[None, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
if (kv_cache.numel() == 0
or prefill_meta.block_tables.numel() == 0):
output = self._run_sdpa_forward(query,
key,
value,
prefill_meta,
attn_type=attn_type)
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
else:
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = decode_meta.get_seq_len_block_table_args(attn_type)
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
attn_metadata.max_decode_seq_len,
block_tables_arg,
seq_lens_arg,
max_seq_len_arg,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......@@ -257,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_sdpa_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: TorchSDPAMetadata,
    attn_type: AttentionType = AttentionType.DECODER,
):
    '''
    Run scaled-dot-product attention one sequence at a time and gather
    the per-sequence results into a single output tensor.

    Arguments:

    * query, key, value: tensors whose dim 0 is sliced per sequence
      (presumably [num_tokens, num_heads, head_size] after the caller's
      reshape -- confirm against the caller)
    * attn_metadata: supplies the per-sequence lengths and caches the
      attention bias for the given attention type
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Tensor shaped like the incoming `query`
    '''
    # Expand KV heads so that key/value line up with the query heads.
    if self.num_kv_heads != self.num_heads:
        repeats = self.num_queries_per_kv
        key = key.repeat_interleave(repeats, dim=1)
        value = value.repeat_interleave(repeats, dim=1)

    # Reuse a previously stored bias for this attention type; otherwise
    # build it once and store it on the metadata.
    attn_masks = attn_metadata.get_attn_bias(attn_type)
    if attn_masks is None:
        if self.alibi_slopes is not None:
            attn_masks = _make_alibi_bias(
                self.alibi_slopes, query.dtype,
                attn_metadata.seq_lens)  # type: ignore
        elif self.sliding_window is not None:
            assert attn_metadata.seq_lens is not None
            attn_masks = _make_sliding_window_bias(
                attn_metadata.seq_lens, self.sliding_window,
                query.dtype)  # type: ignore
        else:
            lens_for_masks, _ = attn_metadata.get_seq_lens(attn_type)
            attn_masks = [None] * len(lens_for_masks)
        attn_metadata.set_attn_bias(attn_masks, attn_type)

    # Allocate the result before the layout change so it keeps the
    # layout of the incoming query.
    output = torch.empty_like(query)

    # Move the token axis to second-to-last position for SDPA.
    token_axis = query.dim() - 2
    query = query.movedim(0, token_axis)
    key = key.movedim(0, key.dim() - 2)
    value = value.movedim(0, value.dim() - 2)

    # Only decoder self-attention is causal, and only when no explicit
    # mask is required.
    use_causal = (attn_type == AttentionType.DECODER) and not self.need_mask

    seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
    q_begin = 0
    kv_begin = 0
    for q_len, kv_len, bias in zip(seq_lens_q, seq_lens_kv, attn_masks):
        q_end = q_begin + q_len
        kv_end = kv_begin + kv_len
        attended = scaled_dot_product_attention(
            query[None, :, q_begin:q_end, :],
            key[None, :, kv_begin:kv_end, :],
            value[None, :, kv_begin:kv_end, :],
            attn_mask=bias,
            dropout_p=0.0,
            is_causal=use_causal,
            scale=self.scale)
        # Drop the batch axis and restore token-major layout.
        output[q_begin:q_end, :, :] = attended.squeeze(0).movedim(
            token_axis, 0)
        q_begin = q_end
        kv_begin = kv_end
    return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
......
......@@ -38,18 +38,12 @@ def is_block_tables_empty(block_tables: Union[None, Dict]):
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int, sliding_window: int):
    """
    Compute the start index of slot mapping.

    Args:
        is_prompt: True for a prefill (prompt) step, False otherwise.
        query_len: Number of query tokens in this step.
        context_len: Kept in the signature for callers; it does not
            influence the result.
        sliding_window: Sliding-window size, or None when sliding-window
            attention is not in use.

    Returns:
        0 unless this is a prefill step with a sliding window, in which
        case the index of the first token that still falls inside the
        window (never negative).
    """
    if not is_prompt or sliding_window is None:
        return 0
    # When prefill, we use it to not write slots to kv cache
    # to save memory.
    return max(0, query_len - sliding_window)
......@@ -138,8 +132,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
......@@ -180,9 +172,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
......@@ -312,7 +304,8 @@ class CommonAttentionState(AttentionState):
slot_mapping=self._graph_slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=None,
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
......
......@@ -118,6 +118,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
......@@ -445,7 +448,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query: torch.Tensor,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
......@@ -489,6 +492,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
......@@ -522,7 +527,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache is not None):
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
......@@ -554,25 +559,32 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self.kv_cache_dtype,
k_scale, v_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
if attn_type == AttentionType.ENCODER:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_encoder_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
elif attn_type == AttentionType.DECODER:
# Decoder self-attention supports chunked prefill.
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
else: # attn_type == AttentionType.ENCODER_DECODER
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
if attn_metadata.num_encoder_tokens is not None:
num_encoder_tokens = attn_metadata.num_encoder_tokens
else:
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
......@@ -580,15 +592,15 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# QKV for prefill.
query = query[:num_prefill_tokens]
if key is not None and value is not None:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
key = key[:num_encoder_tokens]
value = value[:num_encoder_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
......
......@@ -42,10 +42,12 @@ class Attention(nn.Module):
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
sliding_window = cache_config.sliding_window
is_attention_free = cache_config.is_attention_free
else:
kv_cache_dtype = "auto"
block_size = 16
sliding_window = None
is_attention_free = False
if num_kv_heads is None:
num_kv_heads = num_heads
......@@ -76,9 +78,9 @@ class Attention(nn.Module):
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size, blocksparse_params
attn_backend = get_attn_backend(head_size, sliding_window, dtype,
kv_cache_dtype, block_size,
is_attention_free, blocksparse_params
is not None)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
......@@ -90,7 +92,7 @@ class Attention(nn.Module):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
......
......@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
FLASHINFER = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
def backend_name_to_enum(backend_name: str) -> _Backend:
......@@ -88,13 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
@lru_cache(maxsize=None)
def get_attn_backend(
num_heads: int,
head_size: int,
num_kv_heads: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
......@@ -105,9 +105,8 @@ def get_attn_backend(
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
backend = which_attn_to_use(head_size, sliding_window, dtype,
kv_cache_dtype, block_size, is_attention_free)
if backend == _Backend.FLASH_ATTN:
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
......@@ -146,23 +145,31 @@ def get_attn_backend(
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
elif backend == _Backend.NO_ATTENTION:
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
return PlaceholderAttentionBackend
else:
raise ValueError("Invalid attention backend.")
def which_attn_to_use(
num_heads: int,
head_size: int,
num_kv_heads: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
) -> _Backend:
"""Returns which flash attention backend to use."""
# Default case.
selected_backend = _Backend.FLASH_ATTN
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
return _Backend.NO_ATTENTION
# Check whether a particular choice of backend was
# previously forced.
#
......
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Holds the token ids together with the accumulated log probability.
    `text` stays unset until the sequence is about to be handed back to
    the user.
    """
    # Token ids of the sequence; the prompt tokens are included.
    tokens: List[int]
    # Accumulated log probability of the sequence.
    cum_logprob: float = 0.0
    # Decoded text; optional, see the class docstring.
    text: Optional[str] = None
@dataclass
class BeamSearchOutput:
    """Result of a beam search.

    Wraps the list of best beam search sequences; the list has as many
    entries as the beam width.
    """
    # Best sequences found by the search.
    sequences: List[BeamSearchSequence]
class BeamSearchInstance:
    """Beam-search state for one prompt: the current beams and the
    sequences collected as completed."""

    def __init__(self, prompt_tokens: List[int]):
        # Start with a single beam that holds just the prompt tokens.
        initial_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [initial_beam]
        # Nothing has completed yet.
        self.completed: List[BeamSearchSequence] = []
def get_beam_search_score(
    tokens: List[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Return the length-penalized beam search score of a sequence.

    The cumulative log probability is divided by the sequence length
    raised to `length_penalty`; a trailing EOS token is not counted in
    the length.

    Adapted from
    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
    """
    total_len = len(tokens)
    ends_with_eos = tokens[-1] == eos_token_id
    # A trailing EOS does not count toward the length.
    effective_len = total_len - 1 if ends_with_eos else total_len
    return cumulative_logprob / (effective_len**length_penalty)
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
length_penalty)
return sort_beams_key
......@@ -12,11 +12,11 @@ from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.inputs import PromptInputs
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
......@@ -75,7 +75,6 @@ def run_vllm(
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
......@@ -89,11 +88,9 @@ def run_vllm(
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
use_new_beam_search_impl: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
......@@ -117,7 +114,6 @@ def run_vllm(
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
)
......@@ -129,13 +125,12 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
# warmup
warmup_prompts = []
warmup_sampling_params = []
......@@ -144,9 +139,8 @@ def run_vllm(
warmup_sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
......@@ -158,7 +152,7 @@ def run_vllm(
# dummy_prompt_token_ids = np.random.randint(10000,
# size=(args.num_prompts,
# args.input_len))
# dummy_inputs: List[PromptInputs] = [{
# dummy_prompts: List[PromptType] = [{
# "prompt_token_ids": batch
# } for batch in dummy_prompt_token_ids.tolist()]
......@@ -171,22 +165,27 @@ def run_vllm(
# for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
# run_to_completion()
if not use_new_beam_search_impl:
use_beam_search = False
if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
assert use_beam_search
prompts = [prompt for prompt, _, _ in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
start = time.perf_counter()
llm.beam_search(prompts,
beam_width=n,
max_tokens=output_len,
ignore_eos=True)
llm.beam_search(
prompts,
BeamSearchParams(
beam_width=n,
max_tokens=output_len,
ignore_eos=True,
))
end = time.perf_counter()
return end - start
......@@ -199,7 +198,6 @@ async def run_vllm_async(
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
......@@ -213,7 +211,6 @@ async def run_vllm_async(
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
......@@ -241,7 +238,6 @@ async def run_vllm_async(
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False,
disable_log_requests=True,
......@@ -258,9 +254,8 @@ async def run_vllm_async(
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
......@@ -282,11 +277,9 @@ def run_hf(
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
......@@ -318,7 +311,7 @@ def run_hf(
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
do_sample=True,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
......@@ -378,40 +371,37 @@ def main(args: argparse.Namespace):
if args.async_engine:
run_args = [
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.tensor_parallel_size, args.seed, args.n,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc
args.download_dir, args.load_format, args.disable_async_output_proc
]
else:
run_args = [
warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.tensor_parallel_size, args.seed, args.n,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc
args.download_dir, args.load_format, args.disable_async_output_proc
]
if args.async_engine:
run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args))
else:
elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
elapsed_time = run_vllm(*run_args)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code)
args.hf_max_batch_size, args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
......@@ -473,12 +463,10 @@ if __name__ == "__main__":
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument('--num-iters-warmup',
type=int,
default=1,
help='Number of iterations to run for warmup.')
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
......@@ -543,9 +531,6 @@ if __name__ == "__main__":
type=int,
default=1,
help="Maximum number of forward steps per scheduler call.")
parser.add_argument("--use-v2-block-manager",
action='store_true',
help="Enable block manager v2.")
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
......@@ -633,8 +618,6 @@ if __name__ == "__main__":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:
......
__commit__ = "93ec62b8556e279d2c050bdc1c3247831bd39466"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment