Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
import argparse
import copy
import json
import math
import os
from pathlib import Path
from typing import Any, List, Optional, Tuple
import matplotlib.pyplot as plt
import pandas as pd
## JSON parsing utils ####
def largest_dist_from_leaf(node: dict, depth: int = 0):
    """Return the length of the longest path from ``node`` down to a leaf.

    Args:
        node: Tree node; its ``"children"`` entry holds the child nodes.
        depth: Value the count starts from (the level assigned to ``node``).

    Returns:
        ``depth`` plus the number of edges on the longest downward path.
    """
    deepest = depth
    pending = [(node, depth)]
    while pending:
        current, level = pending.pop()
        # The deepest node reached is necessarily a leaf, so tracking the
        # maximum level over all nodes gives the maximum over the leaves.
        if level > deepest:
            deepest = level
        pending.extend((child, level + 1) for child in current["children"])
    return deepest
def get_entries_at_depth(depth: int,
                         entries_and_traces: List[Tuple[Any, Any]],
                         node: dict,
                         curr_depth: int = 0,
                         trace=()):
    """Collect the entries that sit a fixed height above the leaves.

    Args:
        depth: ``-1`` or ``-2``; a node is selected when the longest path
            below it is ``abs(depth) - 1`` edges long.
        entries_and_traces: Output list; ``(entry, trace)`` pairs are
            appended to it in place.
        node: Current tree node, with ``"entry"`` and ``"children"`` keys.
        curr_depth: Distance of ``node`` from the root of the walk.
        trace: Names of the ancestors of ``node``, nearest ancestor first.
    """
    # assert that the query is at kernel or module level
    assert depth in (-1, -2)

    wanted_height = abs(depth) - 1
    height = largest_dist_from_leaf(node)

    if curr_depth == 0 and height <= wanted_height:
        # The tree is not tall enough! Record the root itself and stop.
        entries_and_traces.append((node["entry"], trace))
        return

    if height == wanted_height:
        entries_and_traces.append((node["entry"], trace))

    # Children see this node as their nearest ancestor.
    lineage = (node["entry"]["name"], ) + trace
    for child in node["children"]:
        get_entries_at_depth(depth,
                             entries_and_traces,
                             child,
                             curr_depth=curr_depth + 1,
                             trace=lineage)
def fold_nodes(root: dict, nodes_to_fold: List[str]):
    """Turn every node whose name is in ``nodes_to_fold`` into a leaf.

    The tree is modified in place: a matching node has its ``"children"``
    replaced by an empty list, and its former subtree is not visited.

    Args:
        root: Root node of the tree.
        nodes_to_fold: Entry names whose children should be discarded.

    Returns:
        The same ``root`` object.
    """
    pending: List[dict] = [root]
    while pending:
        current = pending.pop()
        if current['entry']['name'] in nodes_to_fold:
            current["children"] = []
        else:
            pending.extend(current["children"])
    return root
## Operation name cleanup utils ####
def trim_string_back(string: str, width: int) -> str:
    """Shorten ``string`` from the back so it fits in ``width`` characters.

    Strings that already fit are returned untouched. Otherwise enough
    trailing characters are removed to leave room for a three-character
    ellipsis, which is appended only when more than three characters of the
    original survive.
    """
    if len(string) <= width:
        return string

    cut = len(string) - width + 3
    kept = string[:-cut]
    if len(kept) > 3:
        return kept + "..."
    return kept
def shorten_plot_legend_strings(legend, max_char_len: int):
    """Abbreviate and truncate every text entry of ``legend`` in place.

    Each label is first passed through ``abbreviate_known_names`` and then
    cut to at most ``max_char_len`` characters with ``trim_string_back``.
    """
    for label in legend.get_texts():
        compact = abbreviate_known_names(label.get_text())
        label.set_text(trim_string_back(compact, max_char_len))
def abbreviate_known_names(name: str) -> str:
    """Replace well-known long substrings of ``name`` with short forms.

    The substitutions are applied one after another in the listed order, so
    ``bfloat16`` is rewritten before the shorter ``float16`` pattern runs.
    """
    replacements = (
        ("MergedColumnParallelLinear", "MCPLinear"),
        ("QKVParallelLinear", "QKVPLinear"),
        ("RowParallelLinear", "RPLinear"),
        ("weight=", "w="),
        ("bfloat16", "bf16"),
        ("float16", "f16"),
    )
    for long_form, short_form in replacements:
        name = name.replace(long_form, short_form)
    return name
def attempt_to_make_names_unique(entries_and_traces):
    """Disambiguate entries that share a name by appending ancestor names.

    For every name that occurs more than once, the traces of the entries
    carrying it are compared position by position. The entries are renamed
    to ``name <- ancestor <- ...`` using the trace up to and including the
    first position where the traces disagree. Entries are modified in
    place; nothing is returned.

    Args:
        entries_and_traces: List of ``(entry, trace)`` pairs where ``entry``
            is a dict with a ``"name"`` key and ``trace`` is a tuple of
            ancestor names.
    """
    seen, duplicated = set(), set()
    for entry, _ in entries_and_traces:
        label = entry["name"]
        if label in seen:
            duplicated.add(label)
        else:
            seen.add(label)

    for label in duplicated:
        group = [(entry, trace) for entry, trace in entries_and_traces
                 if entry["name"] == label]

        # Walk the traces in lockstep; zip stops at the shortest trace.
        split_at = None
        columns = zip(*[trace for _, trace in group])
        for position, column in enumerate(columns):
            if any(item != column[0] for item in column):
                split_at = position
                break

        if split_at is None:
            # can't create a unique name, leave the names as they are;
            # they will get aggregated by the pivot_table call
            continue

        for entry, trace in group:
            pieces = (entry["name"], ) + trace[:split_at + 1]
            entry["name"] = " <- ".join(pieces)
## Operation grouping utils ####
def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
    """Group operations in the given dataframe by some high-level ops like,
    - gemms
    - attention
    - rms_norm
    etc.

    Every column label of ``trace_df`` is tested against an ordered list of
    categories; the first category that matches claims the column. For each
    category with at least one member a new column holding the row-wise sum
    of its members is added, and afterwards all claimed columns are dropped.
    Unclaimed columns are left as they are.

    The dataframe is modified in place and the same object is returned.
    """

    def is_quant(op_name: str):
        return ("scaled_fp8_quant" in op_name
                or "scaled_int8_quant" in op_name)

    def is_gemm_op(op_name: str):
        if is_quant(op_name):
            return False
        gemm_markers = ("xmma_gemm", "gemv2T_kernel", "splitKreduce",
                        "void cutlass::Kernel", "void cutlass::device_kernel",
                        "s16816gemm")
        return any(marker in op_name for marker in gemm_markers)

    def is_nccl_op(op_name: str):
        return "nccl" in op_name.lower()

    def nccl_with(*markers: str):
        # Predicate: an nccl op whose lowercased name holds any marker.
        def matches(op_name: str):
            lowered = op_name.lower()
            return is_nccl_op(op_name) and any(m in lowered for m in markers)

        return matches

    # (new column label, predicate) in precedence order.
    categories = [
        ('attention', lambda op: "flash_fwd" in op
         or "reshape_and_cache_flash_kernel" in op),
        ('quant_ops', is_quant),
        ('gemm_ops', is_gemm_op),
        ('rms_norm_ops', lambda op: "rms_norm_kernel" in op),
        ('vocab_embed_ops', lambda op: "vocabparallelembed" in op.lower()),
        ('mem_ops', lambda op: "memcpy" in op.lower()
         or "memset" in op.lower()),
        ('elementwise_ops', lambda op: "elementwise_kernel" in op),
        # nccl ops
        ('nccl_all_reduce_ops', nccl_with("all_reduce", "allreduce")),
        ('nccl_gather_ops', nccl_with("gather")),
        ('nccl_broadcast_ops', nccl_with("broadcast")),
        ('nccl_other_ops', is_nccl_op),
        # Reduce ops types
        ('cross_device_reduce_1stage_ops',
         lambda op: "cross_device_reduce_1stage" in op),
        ('cross_device_reduce_2stage_ops',
         lambda op: "cross_device_reduce_2stage" in op),
        ('custom_ar_all_reduce_unreg_ops',
         lambda op: "_C_custom_ar::all_reduce_unreg" in op),
        ('reduce_kernel_ops', lambda op: "reduce_kernel" in op),
    ]

    # Assign every column to the first category that accepts it.
    unclaimed = list(trace_df)
    assignments = []
    for label, accepts in categories:
        members = [op for op in unclaimed if accepts(op)]
        unclaimed = [op for op in unclaimed if op not in members]
        assignments.append((label, members))

    # Add one aggregate column per non-empty category.
    for label, members in assignments:
        if len(members):
            trace_df[label] = trace_df[members].agg("sum", axis=1)

    # Remove the individual columns that were folded into an aggregate.
    claimed = [op for _, members in assignments for op in members]
    trace_df.drop(claimed, axis=1, inplace=True)
    return trace_df
## Data plotting utils ####
def plot_trace_df(traces_df: pd.DataFrame,
                  plot_metric: str,
                  plot_title: str,
                  output: Optional[Path] = None):
    """Draw one stacked bar per phase and save the figure.

    Args:
        traces_df: Long-form dataframe with ``phase`` and ``name`` columns
            plus a ``plot_metric`` column; it is pivoted to one row per
            phase and one column per name before plotting.
        plot_metric: Column whose summed values become the bar heights.
        plot_title: Title placed above the plot.
        output: Destination handed straight to ``savefig``; callers should
            supply a path.

    The figure is closed before returning, so calling this repeatedly does
    not accumulate open matplotlib figures.
    """
    phases = traces_df['phase'].unique()
    traces_df = traces_df.pivot_table(index="phase",
                                      columns="name",
                                      values=plot_metric,
                                      aggfunc="sum")

    traces_df = group_trace_by_operations(traces_df)

    # Make the figure
    fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True)
    try:
        # Draw the stacked bars
        ops = list(traces_df)
        bottom = [0] * len(phases)
        for op in ops:
            values = [traces_df[op][phase] for phase in phases]
            # A missing (name, phase) combination pivots to NaN; plot it as 0.
            values = [0.0 if math.isnan(x) else x for x in values]
            ax.bar(phases, values, label=op, bottom=bottom)
            bottom = [base + value for base, value in zip(bottom, values)]

        # Write the values as text on the bars
        for bar in ax.patches:
            if bar.get_height() != 0:
                ax.text(bar.get_x() + bar.get_width() / 2,
                        bar.get_height() / 2 + bar.get_y(),
                        f"{round(bar.get_height(), 2)}",
                        ha='center',
                        color='w',
                        weight='bold',
                        size=5)

        # Setup legend
        handles, labels = ax.get_legend_handles_labels()
        legend = fig.legend(handles,
                            labels,
                            loc='center left',
                            bbox_to_anchor=(1, 1))
        shorten_plot_legend_strings(legend, 50)

        # Setup labels and title
        plt.setp(ax.get_xticklabels(), rotation=90)
        ax.set_ylabel(plot_metric)
        fig.suptitle(plot_title)

        fig.savefig(output, bbox_inches='tight')
        print("Created: ", output)
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)
def main(
        json_trace: Path,
        output_directory: Path,
        depth: int,  # Fetch/Plot operations at this depth of the Json tree
        plot_metric: str,
        make_names_unique: bool,
        top_k: int,
        json_nodes_to_fold: List[str],
        step_plot_interval: int = 4):
    """Read a layer-wise profile and write prefill / decode bar charts.

    Args:
        json_trace: Path of the JSON profile. Its first key must be
            ``context``; every other key must contain ``prefill`` or
            ``decode`` and exactly one may be a prefill.
        output_directory: Directory that receives ``prefill.png`` and
            ``decode_steps.png``.
        depth: ``-1`` or ``-2``; forwarded to ``get_entries_at_depth``.
        plot_metric: Dataframe column to plot.
        make_names_unique: Whether to run ``attempt_to_make_names_unique``
            on each step before building the dataframes.
        top_k: If truthy, the smallest entries by ``cuda_time_us`` of each
            step are renamed to ``others`` so at most ``top_k`` names remain.
        json_nodes_to_fold: Node names whose children are discarded.
        step_plot_interval: Plot every ``step_plot_interval``-th decode
            step; the last decode step is always included.
    """

    def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame:

        def get_entries_and_traces(key: str):
            entries_and_traces: List[Tuple[Any, Any]] = []
            for root in profile_json[key]["summary_stats"]:
                # Fold nodes in the traces as per user request. i.e. simply
                # make the requested nodes leaf-nodes.
                root = fold_nodes(root, json_nodes_to_fold)
                get_entries_at_depth(depth, entries_and_traces, root)
            return entries_and_traces

        def keep_only_top_entries(df: pd.DataFrame,
                                  metric: str,
                                  top_k: int = 9) -> pd.DataFrame:
            # Everything except the (top_k - 1) largest rows becomes "others".
            df.loc[df.nsmallest(len(df) - top_k + 1, metric).index,
                   ["name"]] = "others"
            return df

        # Get data for each key
        traces = [get_entries_and_traces(key) for key in step_keys]

        # Attempt some cleanup
        if make_names_unique:
            for trace in traces:
                attempt_to_make_names_unique(trace)

        # To pandas dataframe
        trace_dfs = [
            pd.DataFrame([entry for entry, _ in trace]).fillna(0)
            for trace in traces
        ]

        # Respect top_k
        if top_k:
            trace_dfs = [
                keep_only_top_entries(trace_df, "cuda_time_us", top_k)
                for trace_df in trace_dfs
            ]

        # Fill in information about the step-keys
        for trace_df, step_key in zip(trace_dfs, step_keys):
            trace_df['phase'] = step_key

        # Combine all data frames so they can be put in a single plot
        traces_df = pd.concat(trace_dfs)

        # Add a derived metric `cuda_time_ms`
        traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000
        traces_df = traces_df.fillna(0)

        return traces_df

    def make_plot_title_suffix(profile_json: dict) -> str:
        context = profile_json["context"]
        sparsity = context.get('sparsity', None)
        return (f"{context['model']}\n"
                f"Batch={context['batch_size']}, "
                f"PromptLen={context['prompt_len']}, "
                f"OutputLen={context['output_len']},"
                f"NumGpus={context['tensor_parallel_size']}"
                f"{', Sparsity ' + sparsity if sparsity else ''}")

    # Accept a plain string as well, as the command-line driver may pass one.
    output_directory = Path(output_directory)

    profile_json = None
    with open(json_trace, "r") as f:
        profile_json = json.load(f)
    assert profile_json is not None

    # Get all `llm.generate.step()` profile
    step_traces = list(profile_json.keys())
    assert (step_traces[0] == 'context')
    step_traces = step_traces[1:]  # have only prefill and decodes
    prefills = [key for key in step_traces if "prefill" in key]
    all_decodes = [key for key in step_traces if "decode" in key]
    assert len(prefills) + len(all_decodes) == len(step_traces)
    assert len(prefills) == 1

    # The interval is an explicit argument so that main() does not depend on
    # a module-level `args` that only exists when run as a script.
    decodes = all_decodes[::step_plot_interval]
    if decodes and decodes[-1] != all_decodes[-1]:
        # Always have the last decode
        decodes.append(all_decodes[-1])

    plot_title_suffix = make_plot_title_suffix(profile_json)

    prefill_traces = prepare_data(profile_json, prefills)
    plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix,
                  output_directory / Path("prefill.png"))

    if decodes:
        decode_traces = prepare_data(profile_json, decodes)
        plot_trace_df(decode_traces, plot_metric,
                      "decodes " + plot_title_suffix,
                      output_directory / Path("decode_steps.png"))
    else:
        # A profile without decode steps has nothing to put in that plot.
        print("No decode steps found; skipping decode_steps.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--json-trace",
        type=str,
        required=True,
        help="json trace file output by examples/offline_profile.py")
    parser.add_argument("--output-directory",
                        type=str,
                        required=False,
                        help="Directory to output plots")
    parser.add_argument("--level",
                        type=str,
                        default="module",
                        choices=["module", "kernel"])
    parser.add_argument("--top-k",
                        type=int,
                        default=12,
                        help="Only graph the top `top_k` entries by time.")
    parser.add_argument("--fold-json-node",
                        nargs='+',
                        default=['Sampler', 'LogitsProcessor'],
                        help='Do not plot the children of these nodes. Let, '
                        'the node represent the aggregate of all its '
                        'children')
    parser.add_argument("--plot-metric",
                        type=str,
                        default="cuda_time_ms",
                        help='Metric to plot. some options are cuda_time_ms, '
                        'pct_cuda_time')
    parser.add_argument(
        "--step-plot-interval",
        type=int,
        default=4,
        help="For every `step_plot_interval` steps, plot 1 step")

    args = parser.parse_args()

    # Prepare/Extract relevant args
    make_names_unique = False
    if args.level == "module":
        depth = -2
        make_names_unique = True
    elif args.level == "kernel":
        depth = -1
    else:
        raise Exception(f"Unexpected level value ({args.level})")

    output_directory = (Path(args.output_directory)
                        if args.output_directory else Path(
                            args.json_trace).parent)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_directory, exist_ok=True)

    main(Path(args.json_trace),
         output_directory,
         depth,
         args.plot_metric,
         make_names_unique,
         args.top_k,
         args.fold_json_node,
         step_plot_interval=args.step_plot_interval)
#!/usr/bin/env python3
# Copyright (c) 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
# Modified version of: https://chromium.googlesource.com/chromium/tools/depot_tools.git/+/refs/heads/main/post_build_ninja_summary.py
"""Summarize the last ninja build, invoked with ninja's -C syntax.
> python3 tools/report_build_time_ninja.py -C build/..
Typical output looks like this:
```
Longest build steps for .cpp.o:
1.0 weighted s to build ...torch_bindings.cpp.o (12.4 s elapsed time)
2.0 weighted s to build ..._attn_c.dir/csrc... (23.5 s elapsed time)
2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
Longest build steps for .so (linking):
0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
Longest build steps for .cu.o:
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time)
15.3 weighted s to build ...machete_mm_... (183.6 s elapsed time)
15.3 weighted s to build ...machete_mm_... (183.7 s elapsed time)
15.5 weighted s to build ...machete_mm_... (185.6 s elapsed time)
15.5 weighted s to build ...machete_mm_... (185.9 s elapsed time)
15.5 weighted s to build ...machete_mm_... (186.2 s elapsed time)
37.4 weighted s to build ...scaled_mm_c3x.cu... (449.0 s elapsed time)
43.9 weighted s to build ...scaled_mm_c2x.cu... (527.4 s elapsed time)
344.8 weighted s to build ...attention_...cu.o (1087.2 s elapsed time)
1110.0 s weighted time (10120.4 s elapsed time sum, 9.1x parallelism)
134 build steps completed, average of 0.12/s
```
"""
import argparse
import errno
import fnmatch
import os
import sys
from collections import defaultdict
# The number of long build times to report: SummarizeEntries prints at most
# this many of the slowest steps for each extension group.
long_count = 10
# The number of long times by extension to report. main() increases it by the
# number of --step-types patterns; no other reader of this value is visible
# in this file.
long_ext_count = 10
class Target:
    """One build step from a .ninja_log file.

    Holds the step's start and end time in seconds, the names of the
    outputs it produced, and the weighted duration assigned to it later.
    """

    def __init__(self, start, end):
        """Store the start/end times, given in seconds as floats."""
        self.start, self.end = start, end
        # Output names; filled in by the owner of this object.
        self.targets = []
        self.weighted_duration = 0.0

    def Duration(self):
        """Return the elapsed time of the step in seconds as a float."""
        elapsed = self.end - self.start
        return elapsed

    def SetWeightedDuration(self, weighted_duration):
        """Record the weighted duration, in seconds, passed in as a float."""
        self.weighted_duration = weighted_duration

    def WeightedDuration(self):
        """Return the step's weighted duration in seconds as a float.

        The weighted duration is the elapsed time of the step divided by
        the number of steps running at the same time, so it approximates
        how much this step contributed to the total build time. It should
        never exceed the plain duration; a violation is printed and then
        trips an assertion.
        """
        # Allow for modest floating-point errors
        tolerance = 0.000002
        ceiling = self.Duration() + tolerance
        if self.weighted_duration > ceiling:
            print('%s > %s?' % (self.weighted_duration, self.Duration()))
        assert self.weighted_duration <= ceiling
        return self.weighted_duration

    def DescribeTargets(self):
        """Return a printable, length-limited summary of the output names."""
        # Some build steps generate dozens of outputs, so the joined names
        # are capped at a length that still fits most single-target names.
        limit = 65
        summary = ', '.join(self.targets)
        if len(summary) <= limit:
            return summary
        return summary[:limit] + '...'
# Copied with some modifications from ninjatracing
def ReadTargets(log, show_all):
    """Read the build steps of an open .ninja_log file object ``log``.

    Lines that do not have exactly five tab-separated fields are skipped.
    Lines sharing a command hash are merged into one Target whose
    ``targets`` list collects all their output names.

    Unless ``show_all`` is true, only the most recent build is kept: the
    collected data is discarded whenever an end time goes backwards or a
    known command hash reappears with different start/end times.

    Returns:
        A list of Target objects in the order their hashes were first seen.
    """
    header = log.readline()
    assert header == '# ninja log v5\n', \
        'unrecognized ninja log version %r' % header

    by_hash = {}
    previous_end = 0.0
    for record in log:
        fields = record.strip().split('\t')
        if len(fields) != 5:
            # If ninja.exe is rudely halted then the .ninja_log file may be
            # corrupt. Silently continue.
            continue
        start_ms, end_ms, _restat, output_name, cmdhash = fields
        # Convert from integral milliseconds to float seconds.
        start = int(start_ms) / 1000.0
        end = int(end_ms) / 1000.0

        if not show_all and end < previous_end:
            # Records are written when commands complete, so end times only
            # go backwards when a new build starts; begin a fresh set.
            by_hash = {}

        entry = by_hash.get(cmdhash)
        if (entry is not None and not show_all
                and (entry.start != start or entry.end != end)):
            # Short consecutive builds may not make the end time go
            # backwards; a repeated command with different times is the
            # other sign that a new build began.
            by_hash = {}
            entry = None

        if entry is None:
            entry = by_hash[cmdhash] = Target(start, end)
        previous_end = end
        entry.targets.append(output_name)
    return list(by_hash.values())
def GetExtension(target, extra_patterns):
    """Return the file extension that best represents a target.

    For targets that generate multiple outputs it is important to return a
    consistent 'canonical' extension. Ultimately the goal is to group build
    steps by type.

    Args:
        target: Object whose ``targets`` attribute lists the output names.
        extra_patterns: Optional semicolon-separated fnmatch fragments; the
            first fragment found inside an output name is returned as the
            group label.

    Returns:
        The group label. A target without any outputs yields
        ``'(no extension found)'``.
    """
    # Split once instead of once per output.
    patterns = extra_patterns.split(';') if extra_patterns else []

    # Fallback so that a target with no outputs still gets a label.
    extension = '(no extension found)'
    for output in target.targets:
        for fn_pattern in patterns:
            if fnmatch.fnmatch(output, '*' + fn_pattern + '*'):
                return fn_pattern
        # Not a true extension, but a good grouping.
        if output.endswith('type_mappings'):
            extension = 'type_mappings'
            break

        # Capture two extensions if present. For example: file.javac.jar
        # should be distinguished from file.interface.jar.
        root, ext1 = os.path.splitext(output)
        _, ext2 = os.path.splitext(root)
        extension = ext2 + ext1  # Preserve the order in the file name.

        if len(extension) == 0:
            extension = '(no extension found)'

        if ext1 in ['.pdb', '.dll', '.exe']:
            extension = 'PEFile (linking)'
            # Make sure that .dll and .exe are grouped together and that the
            # .dll.lib files don't cause these to be listed as libraries
            break
        if ext1 in ['.so', '.TOC']:
            extension = '.so (linking)'
            # Attempt to identify linking, avoid identifying as '.TOC'
            break
        # Make sure .obj files don't get categorized as mojo files
        if ext1 in ['.obj', '.o']:
            break
        # Jars are the canonical output of java targets.
        if ext1 == '.jar':
            break
        # Normalize all mojo related outputs to 'mojo'.
        if output.count('.mojom') > 0:
            extension = 'mojo'
            break
    return extension
def SummarizeEntries(entries, extra_step_types):
    """Print a summary of the passed in list of Target objects.

    Computes a weighted duration for every target (its elapsed time divided
    by the number of targets running concurrently), stores it on the target,
    and prints the slowest steps per extension group followed by overall
    totals.

    Args:
        entries: List of Target objects.
        extra_step_types: Optional semicolon-separated fnmatch fragments
            forwarded to GetExtension for grouping.

    An empty ``entries`` list prints a short notice and returns without a
    summary.
    """
    if not entries:
        # Without any build step there is no first event to start from and
        # every average below would be undefined.
        print('No build steps found, no build summary created.')
        return

    # Create a list that is in order by time stamp and has entries for the
    # beginning and ending of each build step (one time stamp may have
    # multiple entries due to multiple steps starting/stopping at exactly
    # the same time). Iterate through this list, keeping track of which
    # tasks are running at all times. At each time step calculate a running
    # total for weighted time so that when each task ends its own weighted
    # time can easily be calculated.
    task_start_stop_times = []

    earliest = -1
    latest = 0
    total_cpu_time = 0
    for target in entries:
        if earliest < 0 or target.start < earliest:
            earliest = target.start
        if target.end > latest:
            latest = target.end
        total_cpu_time += target.Duration()
        task_start_stop_times.append((target.start, 'start', target))
        task_start_stop_times.append((target.end, 'stop', target))
    length = latest - earliest
    weighted_total = 0.0

    # Sort by the time/type records and ignore |target|
    task_start_stop_times.sort(key=lambda times: times[:2])
    # Now we have all task start/stop times sorted by when they happen. If a
    # task starts and stops on the same time stamp then the start will come
    # first because of the alphabet, which is important for making this work
    # correctly.
    # Track the tasks which are currently running.
    running_tasks = {}
    # Record the time we have processed up to so we know how to calculate
    # time deltas.
    last_time = task_start_stop_times[0][0]
    # Track the accumulated weighted time so that it can efficiently be
    # added to individual tasks.
    last_weighted_time = 0.0
    # Scan all start/stop events.
    for timestamp, action_name, target in task_start_stop_times:
        # Accumulate weighted time up to now.
        num_running = len(running_tasks)
        if num_running > 0:
            # Update the total weighted time up to this moment.
            last_weighted_time += (timestamp - last_time) / float(num_running)
        if action_name == 'start':
            # Record the total weighted task time when this task starts.
            running_tasks[target] = last_weighted_time
        if action_name == 'stop':
            # Record the change in the total weighted task time while this
            # task ran.
            weighted_duration = last_weighted_time - running_tasks[target]
            target.SetWeightedDuration(weighted_duration)
            weighted_total += weighted_duration
            del running_tasks[target]
        last_time = timestamp
    assert (len(running_tasks) == 0)

    # Warn if the sum of weighted times is off by more than half a second.
    # NOTE(review): start/end appear to be in seconds by the time they get
    # here, which would make 500 mean 500 s rather than half a second -
    # confirm the intended threshold before changing it.
    if abs(length - weighted_total) > 500:
        print('Warning: Possible corrupt ninja log, results may be '
              'untrustworthy. Length = %.3f, weighted total = %.3f' %
              (length, weighted_total))

    entries_by_ext = defaultdict(list)
    for target in entries:
        extension = GetExtension(target, extra_step_types)
        entries_by_ext[extension].append(target)

    for key, values in entries_by_ext.items():
        print(' Longest build steps for %s:' % key)
        values.sort(key=lambda x: x.WeightedDuration())
        for target in values[-long_count:]:
            print(' %8.1f weighted s to build %s (%.1f s elapsed time)' %
                  (target.WeightedDuration(), target.DescribeTargets(),
                   target.Duration()))

    # A build whose steps all start and end on the same time stamp has zero
    # length; report 0 for the ratios instead of dividing by zero.
    parallelism = total_cpu_time * 1.0 / length if length > 0 else 0.0
    steps_per_second = len(entries) / length if length > 0 else 0.0

    print(' %.1f s weighted time (%.1f s elapsed time sum, %1.1fx '
          'parallelism)' % (length, total_cpu_time, parallelism))
    print(' %d build steps completed, average of %1.2f/s' %
          (len(entries), steps_per_second))
def main():
    """Parse the command line and print a summary of the chosen ninja log.

    The log defaults to ``.ninja_log`` in the current directory, or in the
    ``-C`` build directory when given; ``--log-file`` overrides both.

    Returns:
        ``errno.ENOENT`` if an IOError occurs while reading or summarising
        the log, otherwise ``None``.
    """
    global long_ext_count

    parser = argparse.ArgumentParser()
    parser.add_argument('-C', dest='build_directory', help='Build directory.')
    parser.add_argument(
        '-s',
        '--step-types',
        help='semicolon separated fnmatch patterns for build-step grouping')
    parser.add_argument('--log-file',
                        help="specific ninja log file to analyze.")
    options, _unknown = parser.parse_known_args()

    # An explicit log file wins over the build directory default.
    if options.log_file:
        log_path = options.log_file
    elif options.build_directory:
        log_path = os.path.join(options.build_directory, '.ninja_log')
    else:
        log_path = '.ninja_log'

    step_types = options.step_types
    if step_types:
        # Make room for the extra build types.
        long_ext_count += len(step_types.split(';'))

    try:
        with open(log_path, 'r') as log:
            SummarizeEntries(ReadTargets(log, False), step_types)
    except IOError:
        print('Log file %r not found, no build summary created.' % log_path)
        return errno.ENOENT


if __name__ == '__main__':
    sys.exit(main())
...@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine ...@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput, from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput) EmbeddingRequestOutput, RequestOutput)
...@@ -19,7 +19,7 @@ __all__ = [ ...@@ -19,7 +19,7 @@ __all__ = [
"__version_tuple__", "__version_tuple__",
"LLM", "LLM",
"ModelRegistry", "ModelRegistry",
"PromptInputs", "PromptType",
"TextPrompt", "TextPrompt",
"TokensPrompt", "TokensPrompt",
"SamplingParams", "SamplingParams",
......
import contextlib import contextlib
import functools import functools
from typing import List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch import torch
import torch.library
import vllm.envs as envs import vllm.envs as envs
from vllm._core_ext import ScalarType from vllm._core_ext import ScalarType
...@@ -30,6 +31,16 @@ with contextlib.suppress(ImportError): ...@@ -30,6 +31,16 @@ with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401 import vllm._moe_C # noqa: F401
supports_moe_ops = True supports_moe_ops = True
if TYPE_CHECKING:
def register_fake(fn):
return lambda name: fn
else:
try:
from torch.library import register_fake
except ImportError:
from torch.library import impl_abstract as register_fake
def hint_on_error(fn): def hint_on_error(fn):
...@@ -37,6 +48,15 @@ def hint_on_error(fn): ...@@ -37,6 +48,15 @@ def hint_on_error(fn):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
return fn(*args, **kwargs) return fn(*args, **kwargs)
except NotImplementedError as e:
msg = (
"Error in calling custom op %s: %s\n"
"Not implemented or built, mostly likely because the current current device "
"does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
"incorrectly while building)")
logger.error(msg, fn.__name__, e)
raise NotImplementedError(msg % (fn.__name__, e)) from e
except AttributeError as e: except AttributeError as e:
msg = ( msg = (
"Error in calling custom op %s: %s\n" "Error in calling custom op %s: %s\n"
...@@ -211,7 +231,7 @@ def paged_attention_v2_opt( ...@@ -211,7 +231,7 @@ def paged_attention_v2_opt(
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt) # page attention ops (opt_tc)
def paged_attention_v1_opt_tc( def paged_attention_v1_opt_tc(
out: torch.Tensor, out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
...@@ -453,7 +473,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -453,7 +473,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
if hasattr(torch.ops._C, "gptq_gemm"): if hasattr(torch.ops._C, "gptq_gemm"):
@torch.library.register_fake("_C::gptq_gemm") @register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
...@@ -489,7 +509,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -489,7 +509,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@torch.library.register_fake("_C::gptq_marlin_24_gemm") @register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor, workspace: torch.Tensor,
...@@ -497,7 +517,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -497,7 +517,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_n: int, size_k: int) -> torch.Tensor: size_n: int, size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::gptq_marlin_gemm") @register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor, def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, b_scales: torch.Tensor,
...@@ -514,12 +534,12 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -514,12 +534,12 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
use_fp32_reduce: bool = False) -> torch.Tensor: use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::ggml_dequantize") @register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int, def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor: n: int) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device) return torch.empty((m, n), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_vec_a8") @register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake( def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor, W: torch.Tensor,
X: torch.Tensor, X: torch.Tensor,
...@@ -528,7 +548,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -528,7 +548,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device) return torch.empty((1, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_a8") @register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake( def _ggml_mul_mat_a8_fake(
W: torch.Tensor, W: torch.Tensor,
X: torch.Tensor, X: torch.Tensor,
...@@ -538,7 +558,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -538,7 +558,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
batch = X.size(0) batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device) return torch.empty((batch, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::marlin_qqq_gemm") @register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor, s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor, s_group: torch.Tensor, workspace: torch.Tensor,
...@@ -548,7 +568,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -548,7 +568,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=torch.float16, dtype=torch.float16,
device=a.device) device=a.device)
@torch.library.register_fake("_C::marlin_gemm") @register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int, size_m: int, size_n: int,
...@@ -557,7 +577,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -557,7 +577,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=torch.float16, dtype=torch.float16,
device=a.device) device=a.device)
@torch.library.register_fake("_C::awq_dequantize") @register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int, zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor: thy: int) -> torch.Tensor:
...@@ -568,7 +588,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -568,7 +588,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=scales.dtype, dtype=scales.dtype,
device=scales.device) device=scales.device)
@torch.library.register_fake("_C::awq_gemm") @register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor: split_k_iters: int) -> torch.Tensor:
...@@ -577,7 +597,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -577,7 +597,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=input.dtype, dtype=input.dtype,
device=input.device).sum(0) device=input.device).sum(0)
@torch.library.register_fake("_C::aqlm_gemm") @register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int], codebook_partition_sizes: List[int],
...@@ -593,7 +613,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -593,7 +613,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
output_sizes.append(-1) output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes)) return flat_output.reshape(tuple(output_sizes))
@torch.library.register_fake("_C::aqlm_dequant") @register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake( def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor: codebook_partition_sizes: List[int]) -> torch.Tensor:
...@@ -603,14 +623,14 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -603,14 +623,14 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=codebooks.dtype, dtype=codebooks.dtype,
device=codebooks.device) device=codebooks.device)
@torch.library.register_fake("_C::fp8_marlin_gemm") @register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int, num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor: size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@torch.library.register_fake("_C::machete_gemm") @register_fake("_C::machete_gemm")
def machete_gemm_fake( def machete_gemm_fake(
a: torch.Tensor, a: torch.Tensor,
# Should be the tensor returned by machete_prepack_B # Should be the tensor returned by machete_prepack_B
...@@ -628,41 +648,45 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -628,41 +648,45 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
n = b_q.size(1) n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype) return torch.empty((m, n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::machete_prepack_B") @register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor, def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor: b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight, return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format) memory_format=torch.contiguous_format)
@torch.library.register_fake("_C::causal_conv1d_fwd") @register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], bias_: Optional[torch.Tensor],
seq_idx_: Optional[torch.Tensor], conv_states: Optional[torch.Tensor],
initial_states_: Optional[torch.Tensor], cu_seq_len: Optional[torch.Tensor],
final_states_out_: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor: has_initial_state: Optional[torch.Tensor],
return torch.empty_like(x) silu_activation: bool, pad_slot_id: int):
return None
@torch.library.register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake( @register_fake("_C::causal_conv1d_update")
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool, weight: torch.Tensor,
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: bias_: Optional[torch.Tensor],
return torch.empty_like(x) silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
@torch.library.register_fake("_C::selective_scan_fwd") conv_state_indices: Optional[torch.Tensor],
def selective_scan_fwd_fake( pad_slot_id: int) -> None:
u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, return None
B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], @register_fake("_C::selective_scan_fwd")
delta_softplus: bool, index_: Optional[torch.Tensor], def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
x: Optional[torch.Tensor]) -> List[torch.Tensor]: A: torch.Tensor, B: torch.Tensor,
a = torch.empty_like(u) C: torch.Tensor, D_: Optional[torch.Tensor],
if z_ is not None: z_: Optional[torch.Tensor],
c = torch.empty_like(z_) delta_bias_: Optional[torch.Tensor],
return [a, c] delta_softplus: bool,
else: cu_seq_len: Optional[torch.Tensor],
return [a] cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: Optional[torch.Tensor],
pad_slot_id: int) -> None:
return None
# cutlass # cutlass
...@@ -756,6 +780,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, ...@@ -756,6 +780,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
return output return output
def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                          size_k: int, size_n: int,
                          num_bits: int) -> torch.Tensor:
    """Repack AWQ MoE weights one expert at a time.

    Each slice ``b_q_weight[e]`` along dim 0 is passed to
    ``torch.ops._C.awq_marlin_repack`` and the result is stored in row ``e``
    of the output.

    Args:
        b_q_weight: Quantized weights; dim 0 indexes the experts.
        perm: Accepted for signature compatibility; not read by this function.
        size_k: Must be a multiple of 16.
        size_n: Combined with ``num_bits`` to size the last output dim.
        num_bits: Bit width forwarded to the repack op.

    Returns:
        Tensor of shape
        ``(num_experts, size_k // 16, size_n * (num_bits // 2))`` with the
        device and dtype of ``b_q_weight``.
    """
    expert_count = b_q_weight.shape[0]
    assert size_k % 16 == 0

    packed_shape = (expert_count, size_k // 16, size_n * (num_bits // 2))
    output = torch.empty(packed_shape,
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)

    for expert_idx in range(expert_count):
        expert_weight = b_q_weight[expert_idx]
        output[expert_idx] = torch.ops._C.awq_marlin_repack(
            expert_weight, size_k, size_n, num_bits)
    return output
def gptq_marlin_gemm(a: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, b_scales: torch.Tensor,
...@@ -813,7 +851,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor, ...@@ -813,7 +851,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
if hasattr(torch.ops._C, "permute_cols"): if hasattr(torch.ops._C, "permute_cols"):
@torch.library.register_fake("_C::permute_cols") @register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor, def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor: perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a) return torch.empty_like(a)
...@@ -959,37 +997,41 @@ def ggml_mul_mat_a8( ...@@ -959,37 +997,41 @@ def ggml_mul_mat_a8(
# mamba # mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], bias_: Optional[torch.Tensor],
seq_idx_: Optional[torch.Tensor], conv_states: Optional[torch.Tensor],
initial_states_: Optional[torch.Tensor], query_start_loc: Optional[torch.Tensor],
final_states_out_: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor: has_initial_state: Optional[torch.Tensor],
return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, silu_activation: bool, pad_slot_id: int):
initial_states_, final_states_out_, torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
silu_activation) query_start_loc, cache_indices,
has_initial_state, silu_activation,
pad_slot_id)
def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor, def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor],
bias_: Optional[torch.Tensor], silu_activation: bool,
silu_activation: bool, cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor], conv_state_indices: Optional[torch.Tensor],
) -> torch.Tensor: pad_slot_id: int):
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation, silu_activation, cache_seqlens,
conv_state_indices) conv_state_indices, pad_slot_id)
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, index_: Optional[torch.Tensor], delta_softplus: bool,
x: Optional[torch.Tensor]) -> List[torch.Tensor]: query_start_loc: Optional[torch.Tensor],
return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, cache_indices: Optional[torch.Tensor],
delta_bias_, delta_softplus, index_, has_initial_state: Optional[torch.Tensor],
x) ssm_states: torch.Tensor, pad_slot_id: int):
torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
delta_softplus, query_start_loc,
cache_indices, has_initial_state,
ssm_states, pad_slot_id)
# moe # moe
...@@ -1011,16 +1053,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, ...@@ -1011,16 +1053,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
@torch.library.register_fake("_moe_C::marlin_gemm_moe") @register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor, sorted_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, b_scales: torch.Tensor, topk_ids: torch.Tensor, b_scales: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor, b_zero_points: torch.Tensor, g_idx: torch.Tensor,
workspace: torch.Tensor, b_q_type: ScalarType, perm: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int, size_k: int, b_q_type: ScalarType, size_m: int, size_n: int,
is_k_full: bool, num_experts: int, topk: int, size_k: int, is_k_full: bool, num_experts: int,
moe_block_size: int, replicate_input: bool, topk: int, moe_block_size: int,
replicate_input: bool,
apply_weights: bool) -> torch.Tensor: apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n), return torch.empty((size_m, topk, size_n),
dtype=a.dtype, dtype=a.dtype,
......
...@@ -186,6 +186,9 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -186,6 +186,9 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool use_cuda_graph: bool
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional[ _cached_prefill_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None "BlocksparseFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[ _cached_decode_metadata: Optional[
...@@ -357,6 +360,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -357,6 +360,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
...@@ -373,7 +378,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -373,7 +378,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None: if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
...@@ -399,7 +404,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -399,7 +404,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
assert kv_cache is None \ assert kv_cache.numel() == 0 \
or prefill_meta.block_tables is None \ or prefill_meta.block_tables is None \
or prefill_meta.block_tables.numel() == 0, \ or prefill_meta.block_tables.numel() == 0, \
"Does not support prefix-enabled attention." "Does not support prefix-enabled attention."
......
...@@ -13,152 +13,15 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, ...@@ -13,152 +13,15 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.forward_context import get_forward_context
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
# yapf: disable from vllm.vllm_flash_attn import (flash_attn_varlen_func,
from vllm.vllm_flash_attn import ( flash_attn_with_kvcache)
flash_attn_varlen_func as _flash_attn_varlen_func)
from vllm.vllm_flash_attn import (
flash_attn_with_kvcache as _flash_attn_with_kvcache)
# yapf: enable
@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Optional[List[int]] = None,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Custom-op wrapper that forwards to ``_flash_attn_varlen_func``.

    ``window_size`` arrives as an optional two-element list and is converted
    to a tuple before the call; ``None`` becomes ``(-1, -1)``. All other
    arguments are passed through unchanged by keyword.
    """
    # custom op does not support tuple input, so the tuple is rebuilt here
    real_window_size: Tuple[int, int] = (-1, -1)
    if window_size is not None:
        assert len(window_size) == 2
        real_window_size = (window_size[0], window_size[1])

    attn_out = _flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=real_window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        block_table=block_table,
    )
    return attn_out
@flash_attn_varlen_func.register_fake  # type: ignore
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Optional[List[int]] = None,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake implementation registered for ``flash_attn_varlen_func``.

    Returns an uninitialized tensor with the shape, dtype and device of
    ``q``; every other argument is ignored.
    """
    fake_out = torch.empty_like(q)
    return fake_out
@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
def flash_attn_with_kvcache(
    decode_query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cache_seqlens: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    alibi_slopes: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
) -> torch.Tensor:
    """Custom-op wrapper that forwards to ``_flash_attn_with_kvcache``.

    The query and the two cache tensors are passed positionally; the
    remaining options are passed by keyword, all unchanged.
    """
    attn_out = _flash_attn_with_kvcache(
        decode_query,
        key_cache,
        value_cache,
        cache_seqlens=cache_seqlens,
        block_table=block_table,
        softmax_scale=softmax_scale,
        causal=causal,
        alibi_slopes=alibi_slopes,
        softcap=softcap,
    )
    return attn_out
@flash_attn_with_kvcache.register_fake  # type: ignore
def _(
    decode_query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cache_seqlens: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    alibi_slopes: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
) -> torch.Tensor:
    """Fake implementation registered for ``flash_attn_with_kvcache``.

    Returns an uninitialized tensor with the shape, dtype and device of
    ``decode_query``; every other argument is ignored.
    """
    fake_out = torch.empty_like(decode_query)
    return fake_out
@torch.library.custom_op("vllm::reshape_and_cache_flash",
                         mutates_args=["kv_cache"])
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Write ``key``/``value`` into ``kv_cache`` through the cache op.

    The whole ``kv_cache`` tensor is taken as the argument and indexed
    (``kv_cache[0]`` / ``kv_cache[1]``) inside this op, because Inductor
    cannot deal with in-place operations on views; doing the indexing here
    hides the view operation from Inductor.
    See https://github.com/pytorch/pytorch/issues/131192
    and https://github.com/pytorch/pytorch/issues/130174
    """
    key_cache = kv_cache[0]
    value_cache = kv_cache[1]
    return torch.ops._C_cache_ops.reshape_and_cache_flash(
        key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
        k_scale, v_scale)
@reshape_and_cache_flash.register_fake  # type: ignore
def _(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Fake implementation registered for ``reshape_and_cache_flash``.

    The op returns nothing, so the fake does nothing and returns ``None``.
    """
    return None
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
...@@ -245,8 +108,12 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -245,8 +108,12 @@ class FlashAttentionMetadata(AttentionMetadata):
# |-------------------- seq_len ---------------------| # |-------------------- seq_len ---------------------|
# |-- query_len ---| # |-- query_len ---|
# Maximum query length in the batch. None for decoding. # Maximum query length in the batch.
max_query_len: Optional[int] max_query_len: Optional[int]
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding # Maximum sequence length among prefill batch. 0 if there are decoding
# requests only. # requests only.
max_prefill_seq_len: int max_prefill_seq_len: int
...@@ -305,6 +172,7 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -305,6 +172,7 @@ class FlashAttentionMetadata(AttentionMetadata):
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len, max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0,
max_decode_seq_len=0, max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1], query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
...@@ -331,20 +199,27 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -331,20 +199,27 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None, max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len, max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None, query_start_loc=self.query_start_loc[self.num_prefills:]
seq_start_loc=None, if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None, context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:], block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
) )
return self._cached_decode_metadata return self._cached_decode_metadata
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor], sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int): block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
""" """
Update metadata in-place to advance one decode step. Update metadata in-place to advance one decode step.
""" """
...@@ -355,6 +230,23 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -355,6 +230,23 @@ class FlashAttentionMetadata(AttentionMetadata):
assert num_seqs > num_queries assert num_seqs > num_queries
assert self.use_cuda_graph assert self.use_cuda_graph
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0 assert self.num_prefills == 0
assert self.num_prefill_tokens == 0 assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs assert self.num_decode_tokens == num_seqs
...@@ -366,7 +258,6 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -366,7 +258,6 @@ class FlashAttentionMetadata(AttentionMetadata):
assert self.seq_lens_tensor.shape == (num_seqs, ) assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1 assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0 assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.query_start_loc is not None assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, ) assert self.query_start_loc.shape == (num_queries + 1, )
...@@ -414,8 +305,6 @@ class FlashAttentionMetadataBuilder( ...@@ -414,8 +305,6 @@ class FlashAttentionMetadataBuilder(
self.runner = input_builder.runner self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def _add_seq_group( def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
...@@ -441,9 +330,6 @@ class FlashAttentionMetadataBuilder( ...@@ -441,9 +330,6 @@ class FlashAttentionMetadataBuilder(
self.num_prefill_tokens += token_len self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len) self.prefill_seq_lens.append(seq_len)
else: else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len) self.curr_seq_lens.append(curr_seq_len)
...@@ -467,13 +353,37 @@ class FlashAttentionMetadataBuilder( ...@@ -467,13 +353,37 @@ class FlashAttentionMetadataBuilder(
# Compute slot mapping. # Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables) is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx( start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
is_prompt, query_len, context_len, self.sliding_window, context_len,
self.use_v2_block_manager) self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx, seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables) self.block_size, inter_data.block_tables)
def _get_graph_runner_block_tables(
        self, num_seqs: int,
        block_tables: List[List[int]]) -> torch.Tensor:
    """Copy per-sequence block tables into the runner's graph buffer.

    The first ``num_seqs`` rows of ``self.runner.graph_block_tables`` are
    filled in place from ``block_tables`` (rows whose table is empty are left
    untouched), then converted with ``torch.from_numpy`` and moved to
    ``self.runner.device``.

    Args:
        num_seqs: Number of leading buffer rows to use; must not exceed the
            buffer's first dimension.
        block_tables: One list of block ids per sequence.

    Returns:
        Tensor holding the first ``num_seqs`` rows of the buffer on the
        runner's device.
    """
    # The shape of graph_block_tables is
    # [max batch size, max context len // block size].
    max_batch_size, max_blocks = self.runner.graph_block_tables.shape
    assert max_batch_size >= num_seqs

    graph_block_tables = self.runner.graph_block_tables[:num_seqs]
    for row, table in enumerate(block_tables):
        if not table:
            continue
        # It may be possible to have more blocks allocated due to lookahead
        # slots of multi-step, however, they are not used anyway, so anything
        # past max_blocks can be safely dropped. For shorter tables the slice
        # is the whole table.
        kept = table[:max_blocks]
        graph_block_tables[row, :len(kept)] = kept

    return torch.from_numpy(graph_block_tables).to(
        device=self.runner.device, non_blocking=True)
def build(self, seq_lens: List[int], query_lens: List[int], def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int): cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors. """Build attention metadata with on-device tensors.
...@@ -498,33 +408,22 @@ class FlashAttentionMetadataBuilder( ...@@ -498,33 +408,22 @@ class FlashAttentionMetadataBuilder(
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens) max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens num_decode_tokens = self.num_decode_tokens
num_seqs = len(seq_lens)
if use_captured_graph: if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
# The shape of graph_block_tables is num_seqs, self.block_tables)
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]
block_tables = torch.from_numpy(input_block_tables).to(
device=device, non_blocking=True)
else: else:
block_tables = make_tensor_with_pad( block_tables = make_tensor_with_pad(
self.block_tables, self.block_tables,
...@@ -566,6 +465,7 @@ class FlashAttentionMetadataBuilder( ...@@ -566,6 +465,7 @@ class FlashAttentionMetadataBuilder(
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len, max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len, max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len, max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
...@@ -665,6 +565,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -665,6 +565,8 @@ class FlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
...@@ -679,106 +581,198 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -679,106 +581,198 @@ class FlashAttentionImpl(AttentionImpl):
assert k_scale == 1.0 and v_scale == 1.0, ( assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.") "key/v_scale is not supported in FlashAttention.")
num_tokens, hidden_size = query.shape output = torch.ops.vllm.unified_flash_attention(
# Reshape the query, key, and value tensors. query,
query = query.view(-1, self.num_heads, self.head_size) key,
key = key.view(-1, self.num_kv_heads, self.head_size) value,
value = value.view(-1, self.num_kv_heads, self.head_size) self.num_heads,
self.head_size,
if kv_cache is not None: self.num_kv_heads,
key_cache = kv_cache[0] kv_cache,
value_cache = kv_cache[1] self.kv_cache_dtype,
k_scale,
# Reshape the input keys and values and store them in the cache. v_scale,
# If kv_cache is not provided, the new key and value tensors are self.scale,
# not cached. This happens during the initial memory profiling run. self.sliding_window,
torch.ops.vllm.reshape_and_cache_flash( self.alibi_slopes,
key, self.logits_soft_cap,
value, )
kv_cache,
attn_metadata.slot_mapping.flatten(), return output
self.kv_cache_dtype,
k_scale,
v_scale, @torch.library.custom_op("vllm::unified_flash_attention",
mutates_args=["kv_cache"])
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
) )
num_prefill_tokens = attn_metadata.num_prefill_tokens if decode_meta := attn_metadata.decode_metadata:
num_decode_tokens = attn_metadata.num_decode_tokens # Decoding run.
assert key.shape[0] == num_prefill_tokens + num_decode_tokens # Use flash_attn_varlen_func kernel for speculative decoding
assert value.shape[0] == num_prefill_tokens + num_decode_tokens # because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
# Query for decode. KV is not needed because it is already cached. if decode_meta.max_decode_query_len > 1:
decode_query = query[num_prefill_tokens:] decode_output = flash_attn_varlen_func(
# QKV for prefill. q=decode_query,
query = query[:num_prefill_tokens] k=key_cache,
key = key[:num_prefill_tokens] v=value_cache,
value = value[:num_prefill_tokens] cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
assert query.shape[0] == num_prefill_tokens cu_seqlens_k=decode_meta.seq_start_loc,
assert decode_query.shape[0] == num_decode_tokens max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
prefill_output: Optional[torch.Tensor] = None causal=True,
decode_output: Optional[torch.Tensor] = None alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
if prefill_meta := attn_metadata.prefill_metadata: block_table=decode_meta.block_tables,
# Prompt run. )
if (kv_cache is None or prefill_meta.block_tables is None else:
or prefill_meta.block_tables.numel() == 0): # Use flash_attn_with_kvcache for normal decoding.
# normal attention decode_output = flash_attn_with_kvcache(
# When block_tables are not filled, it means q and k are the q=decode_query.unsqueeze(1),
# prompt, and they have the same length. k_cache=key_cache,
prefill_output = torch.ops.vllm.flash_attn_varlen_func( v_cache=value_cache,
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
decode_output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables, block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor, cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale, softmax_scale=softmax_scale,
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=alibi_slopes,
softcap=self.logits_soft_cap, softcap=logits_soft_cap,
).squeeze(1) ).squeeze(1)
if prefill_output is None: if prefill_output is None:
assert decode_output is not None assert decode_output is not None
return decode_output.view(num_decode_tokens, hidden_size) return decode_output.view(num_decode_tokens, hidden_size)
if decode_output is None: if decode_output is None:
assert prefill_output is not None assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size) return prefill_output.view(num_prefill_tokens, hidden_size)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size) # Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
@unified_flash_attention.register_fake
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation registered for ``unified_flash_attention``.

    Only ``query`` is consulted: the result is a freshly allocated,
    uninitialized tensor created with ``torch.empty_like(query)``. All
    other parameters are accepted so that the signature lines up with the
    real op, but none of them is read here.

    Returns:
        An uninitialized tensor shaped like ``query``.
    """
    placeholder_output = torch.empty_like(query)
    return placeholder_output
...@@ -7,7 +7,7 @@ try: ...@@ -7,7 +7,7 @@ try:
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
import vllm.attention.backends.flash_attn # noqa from vllm.vllm_flash_attn import flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError: except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None BatchDecodeWithPagedKVCacheWrapper = None
...@@ -26,6 +26,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -26,6 +26,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.forward_context import get_forward_context
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad) make_tensor_with_pad)
...@@ -410,18 +411,22 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -410,18 +411,22 @@ class FlashInferMetadata(AttentionMetadata):
return self return self
def advance_step( def advance_step(self,
self, model_input: "ModelInputForGPUWithSamplingMetadata",
model_input: "ModelInputForGPUWithSamplingMetadata", sampled_token_ids: Optional[torch.Tensor],
sampled_token_ids: Optional[torch.Tensor], block_size: int,
block_size: int, num_seqs: int,
num_seqs: int, num_queries: int,
num_queries: int, turn_prefills_into_decodes: bool = False):
):
""" """
Update metadata in-place to advance one decode step. Update metadata in-place to advance one decode step.
""" """
assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with flashinfer yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")
assert num_seqs > 0 assert num_seqs > 0
assert num_queries > 0 assert num_queries > 0
assert model_input.attn_metadata is not None assert model_input.attn_metadata is not None
...@@ -470,8 +475,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -470,8 +475,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.sliding_window = input_builder.sliding_window self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields. # for the precise definition of the following fields.
...@@ -537,9 +540,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -537,9 +540,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
is_profile_run = is_block_tables_empty(block_tables) is_profile_run = is_block_tables_empty(block_tables)
# Compute slot mapping. # Compute slot mapping.
start_idx = compute_slot_mapping_start_idx( start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
is_prompt, query_len, context_len, self.sliding_window, context_len,
self.use_v2_block_manager) self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx, seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables) self.block_size, inter_data.block_tables)
...@@ -591,7 +594,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -591,7 +594,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
device = self.runner.device device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1 use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens num_decode_tokens = self.num_decode_tokens
...@@ -630,7 +632,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -630,7 +632,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int, dtype=torch.int,
device=device, device=device,
) )
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None assert device is not None
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
...@@ -746,7 +747,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -746,7 +747,7 @@ class FlashInferImpl(AttentionImpl):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
...@@ -759,73 +760,132 @@ class FlashInferImpl(AttentionImpl): ...@@ -759,73 +760,132 @@ class FlashInferImpl(AttentionImpl):
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
"are not implemented for " "are not implemented for "
"FlashInferImpl") "FlashInferImpl")
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if attn_metadata.num_prefill_tokens > 0: return torch.ops.vllm.unified_flash_infer(
assert attn_metadata.num_decode_tokens == 0, ( query,
"Chunked prefill is not supported with flashinfer yet.") key,
if attn_metadata.num_decode_tokens > 0: value,
assert attn_metadata.num_prefill_tokens == 0, ( self.num_heads,
"Chunked prefill is not supported with flashinfer yet.") self.head_size,
if kv_cache is not None: self.num_kv_heads,
# Use the same reshape and cache kernel as flash attention. kv_cache,
ops.reshape_and_cache_flash( self.kv_cache_dtype,
key, k_scale,
value, v_scale,
kv_cache[:, 0], self.scale,
kv_cache[:, 1], self.sliding_window,
attn_metadata.slot_mapping.flatten(), self.alibi_slopes,
self.kv_cache_dtype, self.logits_soft_cap,
k_scale, )
v_scale,
@torch.library.custom_op("vllm::unified_flash_infer",
mutates_args=["kv_cache"])
def unified_flash_infer(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
kv_cache: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
softmax_scale: float,
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
current_metadata = get_forward_context()
assert current_metadata is not None
assert isinstance(current_metadata, FlashInferMetadata)
attn_metadata: FlashInferMetadata = current_metadata
num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if attn_metadata.num_prefill_tokens > 0:
assert attn_metadata.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
query = query.contiguous() # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
) )
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
query = query.contiguous(
) # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(
query,
kv_cache,
logits_soft_cap=self.logits_soft_cap,
causal=True)
else: else:
assert attn_metadata.decode_metadata is not None assert prefill_meta is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None assert prefill_meta.prefill_wrapper is not None
output = attn_metadata.decode_metadata.decode_wrapper.forward( output = prefill_meta.prefill_wrapper.forward(
query, query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True)
kv_cache, else:
sm_scale=self.scale, assert attn_metadata.decode_metadata is not None
logits_soft_cap=self.logits_soft_cap, assert attn_metadata.decode_metadata.decode_wrapper is not None
k_scale=k_scale, output = attn_metadata.decode_metadata.decode_wrapper.forward(
v_scale=v_scale) query,
return output.view(num_tokens, hidden_size) kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale)
return output.view(num_tokens, hidden_size)
@unified_flash_infer.register_fake
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation registered for ``unified_flash_infer``.

    Only ``query`` is consulted: an uninitialized tensor is allocated with
    ``torch.empty_like(query)`` and then passed through ``.contiguous()``
    before being returned. All remaining parameters exist so that the
    signature lines up with the real op; none of them is read here.

    Returns:
        A contiguous, uninitialized tensor shaped like ``query``.
    """
    placeholder_output = torch.empty_like(query)
    return placeholder_output.contiguous()
...@@ -167,7 +167,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -167,7 +167,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
...@@ -180,6 +180,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -180,6 +180,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
...@@ -196,7 +198,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -196,7 +198,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None: if kv_cache.numel() > 0:
key_cache, value_cache = self.split_kv_cache( key_cache, value_cache = self.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
ipex_ops.reshape_and_cache( ipex_ops.reshape_and_cache(
...@@ -212,7 +214,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -212,7 +214,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if (kv_cache.numel() == 0
or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, value = value.repeat_interleave(self.num_queries_per_kv,
......
...@@ -9,6 +9,31 @@ from vllm.attention.backends.abstract import (AttentionBackend, ...@@ -9,6 +9,31 @@ from vllm.attention.backends.abstract import (AttentionBackend,
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
                     src_offset: int, dst_offset: int) -> None:
    """Copy a single block from ``src_tensor`` into ``dst_tensor``.

    A region-of-interest (ROI) view covering index ``src_offset`` along the
    first dimension of ``src_tensor`` is copied into the ROI view covering
    index ``dst_offset`` along the first dimension of ``dst_tensor``.

    Args:
        src_tensor: Tensor to read the block from.
        dst_tensor: Tensor to write the block into.
        src_offset: Index of the block along dimension 0 of ``src_tensor``.
        dst_offset: Index of the block along dimension 0 of ``dst_tensor``.
    """

    def create_roi_tensor(
        tensor: ov.Tensor,
        block_number: int,
    ) -> ov.Tensor:
        # The ROI spans [block_number, block_number + 1) on dimension 0 and
        # the full extent reported by get_shape() on the other dimensions.
        # NOTE(review): the begin coordinate is hard-coded with four
        # entries, so this presumably expects rank-4 tensors - confirm.
        lower = ov.runtime.Coordinate([0, 0, 0, 0])
        upper = ov.runtime.Coordinate(tensor.get_shape())
        lower[0] = block_number
        upper[0] = block_number + 1
        # Anything that is not an ov.Tensor is wrapped as ov.RemoteTensor.
        roi_type = (ov.Tensor
                    if isinstance(tensor, ov.Tensor) else ov.RemoteTensor)
        return roi_type(tensor, lower, upper)

    source_block = create_roi_tensor(src_tensor, src_offset)
    destination_block = create_roi_tensor(dst_tensor, dst_offset)
    source_block.copy_to(destination_block)
class OpenVINOAttentionBackend(AttentionBackend): class OpenVINOAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
...@@ -44,13 +69,12 @@ class OpenVINOAttentionBackend(AttentionBackend): ...@@ -44,13 +69,12 @@ class OpenVINOAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
src_kv_cache: ov.Tensor, src_tensor: ov.Tensor,
dst_kv_cache: ov.Tensor, dst_tensor: ov.Tensor,
src_to_dst: torch.Tensor, src_to_dists: List[Tuple[int, int]],
) -> None: ) -> None:
# OpenVINO currently supports only CPU, which does not require for src, dst in src_to_dists:
# swap of KV cache blocks copy_cache_block(src_tensor, dst_tensor, src, dst)
raise NotImplementedError
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
...@@ -59,8 +83,8 @@ class OpenVINOAttentionBackend(AttentionBackend): ...@@ -59,8 +83,8 @@ class OpenVINOAttentionBackend(AttentionBackend):
) -> None: ) -> None:
for src, dst in src_to_dists: for src, dst in src_to_dists:
for key_cache, value_cache in kv_caches: for key_cache, value_cache in kv_caches:
key_cache.data[dst, :] = key_cache.data[src, :] copy_cache_block(key_cache, key_cache, src, dst)
value_cache.data[dst, :] = value_cache.data[src, :] copy_cache_block(value_cache, value_cache, src, dst)
@dataclass @dataclass
......
...@@ -130,7 +130,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -130,7 +130,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
assert tpu_type is not None assert tpu_type is not None
tpu_type = tpu_type.lower() tpu_type = tpu_type.lower()
if "lite" not in tpu_type: if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
if self.num_kv_heads % 2 == 0: if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head" self.megacore_mode = "kv_head"
else: else:
...@@ -143,7 +143,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -143,7 +143,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
...@@ -155,8 +155,10 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -155,8 +155,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [batch_size, seq_len, num_heads * head_size] query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size] key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size] value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size] kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size] kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
with shape [0] for profiling run.
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
Returns: Returns:
shape = [batch_size, seq_len, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
...@@ -173,7 +175,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -173,7 +175,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
value = value.view(batch_size, seq_len, self.num_kv_heads, value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size) self.head_size)
if kv_cache[0] is not None: if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
...@@ -205,36 +207,55 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -205,36 +207,55 @@ class PallasAttentionBackendImpl(AttentionImpl):
output = output.permute(0, 2, 1, 3) output = output.permute(0, 2, 1, 3)
else: else:
# Decoding run. # Decoding run.
assert kv_cache is not None assert kv_cache[0].numel() > 0
query = query.squeeze(dim=1)
pages_per_compute_block = 16 # TODO(woosuk): Tune this value. pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
if self.megacore_mode == "batch" and batch_size % 2 != 0:
megacore_mode = None assert attn_metadata.block_tables is not None
else: assert attn_metadata.context_lens is not None
megacore_mode = self.megacore_mode # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
# block table in SMEM. Therefore, if the block table is too large,
# NOTE(woosuk): A temporary workaround to avoid the error: # the kernel compilation will fail. To avoid this, we split the
# "xla::paged_attention() Expected a value of type 'str' for # batch dimension into smaller chunks and run the kernel multiple
# argument 'megacore_mode' but instead found type 'NoneType'." # times.
if megacore_mode is not None: MAX_SMEM_USAGE = 512 * 1024
output = torch.ops.xla.paged_attention( size_per_seq = 4 * attn_metadata.block_tables.shape[1]
query.squeeze(dim=1), max_num_seq = MAX_SMEM_USAGE // size_per_seq
if batch_size <= max_num_seq:
output = paged_attention(
query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.context_lens, attn_metadata.context_lens,
attn_metadata.block_tables, attn_metadata.block_tables,
pages_per_compute_block, pages_per_compute_block,
megacore_mode=megacore_mode, self.megacore_mode,
) )
else: else:
output = torch.ops.xla.paged_attention( chunk_size = max_num_seq
query.squeeze(dim=1), # Make sure the chunk size is a multiple of 2.
key_cache, chunk_size = chunk_size // 2 * 2
value_cache, num_chunks = (batch_size + chunk_size - 1) // chunk_size
attn_metadata.context_lens,
attn_metadata.block_tables, output = torch.empty_like(query)
pages_per_compute_block, for chunk_idx in range(num_chunks):
) chunk_start = chunk_idx * chunk_size
chunk_end = chunk_start + chunk_size
# NOTE(woosuk): We skip this line because it causes Dynamo
# compilation error. Instead, we rely on the slice operation
# to handle the out-of-bound case.
# chunk_end = min(chunk_end, batch_size)
chunk_output = paged_attention(
query[chunk_start:chunk_end],
key_cache,
value_cache,
attn_metadata.context_lens[chunk_start:chunk_end],
attn_metadata.block_tables[chunk_start:chunk_end],
pages_per_compute_block,
self.megacore_mode,
)
output[chunk_start:chunk_end] = chunk_output
# Reshape the output tensor. # Reshape the output tensor.
return output.reshape(batch_size, seq_len, hidden_size) return output.reshape(batch_size, seq_len, hidden_size)
...@@ -256,3 +277,43 @@ def write_to_kv_cache( ...@@ -256,3 +277,43 @@ def write_to_kv_cache(
value_cache = value_cache.flatten(0, 2) value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key) key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value) value_cache.index_copy_(0, slot_mapping, value)
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
) -> torch.Tensor:
    """Invoke ``torch.ops.xla.paged_attention`` with an optional mode.

    The requested ``megacore_mode`` is discarded (treated as ``None``) when
    it equals ``"batch"`` and ``query.shape[0]`` is odd. When the resulting
    mode is ``None`` the op is called without the ``megacore_mode`` keyword;
    otherwise the mode is forwarded as a keyword argument.

    Args:
        query: Query tensor; its first dimension is used as the batch size.
        key_cache: Key cache forwarded unchanged to the op.
        value_cache: Value cache forwarded unchanged to the op.
        context_lens: Context lengths forwarded unchanged to the op.
        block_tables: Block tables forwarded unchanged to the op.
        pages_per_compute_block: Forwarded unchanged to the op.
        megacore_mode: Requested megacore mode, or ``None``.

    Returns:
        Whatever ``torch.ops.xla.paged_attention`` returns.
    """
    num_seqs = query.shape[0]
    effective_mode = megacore_mode
    if megacore_mode == "batch" and num_seqs % 2 != 0:
        effective_mode = None

    # NOTE(woosuk): A temporary workaround to avoid the error:
    # "xla::paged_attention() Expected a value of type 'str' for
    # argument 'megacore_mode' but instead found type 'NoneType'."
    # Hence the keyword is omitted entirely instead of passing None.
    if effective_mode is None:
        return torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
        )

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=effective_mode,
    )
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
# Placeholder attention backend for models like Mamba and embedding models that
# lack attention.
class PlaceholderAttentionBackend(AttentionBackend):
    """Placeholder backend for when no attention is needed."""

    @staticmethod
    def get_name() -> str:
        """Return the name string identifying this backend."""
        backend_name = "placeholder-attn"
        return backend_name

    @staticmethod
    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
        """Return the implementation class paired with this backend."""
        impl_cls = PlaceholderAttentionImpl
        return impl_cls

    @staticmethod
    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        builder_cls = PlaceholderAttentionMetadataBuilder
        return builder_cls

    @staticmethod
    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
        """Return the metadata class paired with this backend."""
        metadata_cls = PlaceholderAttentionMetadata
        return metadata_cls

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class, ``CommonAttentionState``."""
        state_cls = CommonAttentionState
        return state_cls

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return a fixed five-dimensional shape with every extent 1.

        None of the arguments is used; the same shape is returned
        regardless of what is passed in.
        """
        return (1, ) * 5

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are accepted and ignored."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are accepted and ignored."""
        return None
@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
    """Attention metadata for prefill and decode batched together."""

    # (batch_size,). Per-sequence length, i.e. computed tokens + new tokens.
    # None in the decode-only view produced by `decode_metadata`.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Maximum number of query tokens among the requests in the batch.
    max_decode_query_len: Optional[int]

    # Maximum sequence length among the prefill batch. 0 if there are
    # decoding requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among the decode batch. 0 if there are
    # prefill requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). Cumulative subquery lengths of the sequences in the
    # batch, used to index into subquery. E.g., for subquery lengths [4, 6]
    # it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into sequence. E.g., for sequence lengths [4, 6]
    # it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). Context lengths (tokens that are computed so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # of the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Lazily-built prefill-only / decode-only views of this object.
    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the prefill-only view, or None when there are no prefills.

        The view is built on first access and cached on the instance.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.seq_start_loc is not None

        num_prefills = self.num_prefills
        # The prefill entries are the first `num_prefills` items of each
        # per-sequence field; slot mapping and block tables are empty
        # placeholders.
        prefill_view = PlaceholderAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=torch.empty(0),
            seq_lens=self.seq_lens[:num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:num_prefills],
            max_decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:num_prefills],
            block_tables=torch.empty(0),
            use_cuda_graph=False,
        )
        self._cached_prefill_metadata = prefill_view
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
        """Return the decode-only view, or None when there are no decodes.

        The view is built on first access and cached on the instance.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.seq_lens_tensor is not None

        # Decode entries follow the prefill entries in seq_lens_tensor; slot
        # mapping and block tables are empty placeholders.
        decode_view = PlaceholderAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=torch.empty(0),
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=torch.empty(0),
            use_cuda_graph=self.use_cuda_graph,
        )
        self._cached_decode_metadata = decode_view
        return self._cached_decode_metadata
class PlaceholderAttentionMetadataBuilder(
        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
    """Builds PlaceholderAttentionMetadata from a model-input builder."""

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Keep the input builder and its runner; reset all accumulators."""
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Fold one sequence group into the accumulated counters.

        Every sequence contributes its context length. Prompt sequences
        additionally bump the prefill counters; all others must have a query
        length of exactly 1 and bump the decode counters.
        `chunked_prefill_enabled` is accepted but not used here.
        """
        is_prompt = inter_data.is_prompt
        token_lens = [len(t) for t in inter_data.input_tokens]
        # seq_ids and curr_sliding_window_blocks are not read below, but they
        # stay in the zip so iteration stops at the shortest of all inputs.
        per_sequence = zip(inter_data.seq_ids, token_lens,
                           inter_data.orig_seq_lens, inter_data.seq_lens,
                           inter_data.query_lens, inter_data.context_lens,
                           inter_data.curr_sliding_window_blocks)

        for (_seq_id, token_len, seq_len, curr_seq_len, query_len,
             context_len, _sliding_window_block) in per_sequence:
            self.context_lens.append(context_len)

            if not is_prompt:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)
                continue

            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Raises:
            ValueError: If the model's HF config defines
                `attn_logit_softcapping`.
        """
        chunked_prefill_enabled = self.input_builder.chunked_prefill_enabled
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data, chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        hf_config = self.runner.model_config.hf_config
        logits_soft_cap = getattr(hf_config, "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        # Decode queries come after the prefill queries; fall back to 1 when
        # the batch holds no decode entries.
        max_decode_query_len = max(query_lens[self.num_prefills:], default=1)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        # With a captured graph the decode token count is the batch size.
        num_decode_tokens = (batch_size
                             if use_captured_graph else self.num_decode_tokens)

        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)

        # Start-location tensors: a leading zero followed by the running sum
        # of the corresponding lengths, written in place into [1:].
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        # Slot mapping and block tables are empty placeholders.
        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=torch.empty(0),
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=torch.empty(0),
            use_cuda_graph=use_captured_graph,
        )
class PlaceholderAttentionImpl(AttentionImpl):
    """Attention implementation stub whose forward pass is not implemented."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any constructor arguments."""
        return None

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Always raise NotImplementedError."""
        raise NotImplementedError
...@@ -116,9 +116,14 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -116,9 +116,14 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only. # Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed # (batch_size,) A tensor of context lengths (tokens that are computed
# so far). # so far).
context_lens_tensor: Optional[torch.Tensor] context_lens_tensor: Optional[torch.Tensor]
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
...@@ -183,12 +188,22 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -183,12 +188,22 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
) )
return self._cached_decode_metadata return self._cached_decode_metadata
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor], sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int): block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
""" """
Update metadata in-place to advance one decode step. Update metadata in-place to advance one decode step.
""" """
assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with rocm_flash_attn yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")
# When using cudagraph, the num_seqs is padded to the next captured # When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in # batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries # the batch. For --enforce-eager mode, num_seqs == num_queries
...@@ -398,10 +413,14 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -398,10 +413,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
...@@ -414,7 +433,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -414,7 +433,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None: if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
...@@ -451,7 +470,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -451,7 +470,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
assert prefill_meta.seq_lens is not None assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0: if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# triton attention # triton attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
......
...@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -75,6 +75,22 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
seq_lens: Optional[List[int]] seq_lens: Optional[List[int]]
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt # It is a list because it is needed to set per prompt
...@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -82,6 +98,28 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
# from xformer API. # from xformer API.
# will not appear in the __repr__ and __init__ # will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None self.attn_bias: Optional[List[torch.Tensor]] = None
self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
self.cross_attn_bias: Optional[List[torch.Tensor]] = None
@property
def is_all_encoder_attn_metadata_set(self) -> bool:
    '''
    True when every field required for encoder attention is set, i.e.
    encoder_seq_lens, encoder_seq_lens_tensor and max_encoder_seq_len are
    all non-None.
    '''
    required_fields = (
        self.encoder_seq_lens,
        self.encoder_seq_lens_tensor,
        self.max_encoder_seq_len,
    )
    # Identity checks only: a tensor field must never be compared by value.
    return all(field is not None for field in required_fields)
@property
def is_all_cross_attn_metadata_set(self) -> bool:
    '''
    True when every field required for enc/dec cross-attention is set.
    This is a superset of the encoder attention requirements: it also needs
    cross_slot_mapping and cross_block_tables to be non-None.
    '''
    if not self.is_all_encoder_attn_metadata_set:
        return False
    has_slot_mapping = self.cross_slot_mapping is not None
    return has_slot_mapping and self.cross_block_tables is not None
@property @property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
...@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -101,6 +139,136 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
return self return self
def get_seq_lens(
    self,
    attn_type: AttentionType,
):
    '''
    Select the query and key/value sequence lengths that apply to the
    given attention type.

    Arguments:

    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * Sequence lengths for the query
    * Sequence lengths for key & value

    Raises:
    * AttributeError if attn_type is none of the three known types
    '''

    if attn_type == AttentionType.DECODER:
        # Self-attention: query and key/value share the decoder lengths.
        return self.seq_lens, self.seq_lens
    if attn_type == AttentionType.ENCODER:
        # Encoder attention: both sides use the encoder lengths.
        return self.encoder_seq_lens, self.encoder_seq_lens
    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention: decoder-side query, encoder-side key/value.
        return self.seq_lens, self.encoder_seq_lens
    raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_attn_bias(
    self,
    attn_type: AttentionType,
) -> Optional[List[torch.Tensor]]:
    '''
    Return the attention bias stored on this metadata object for the
    given attention type.

    Arguments:

    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:
    * The stored attention bias for that attention type (may be None)

    Raises:
    * AttributeError if attn_type is none of the three known types
    '''

    if attn_type == AttentionType.DECODER:
        bias = self.attn_bias
    elif attn_type == AttentionType.ENCODER:
        bias = self.encoder_attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        bias = self.cross_attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
    return bias
def set_attn_bias(
    self,
    attn_bias: List[torch.Tensor],
    attn_type: AttentionType,
) -> None:
    '''
    Store an attention bias on this metadata object, in the field that
    corresponds to the given attention type.

    Arguments:

    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Raises:
    * AttributeError if attn_type is none of the three known types
    '''

    if attn_type == AttentionType.DECODER:
        target_field = "attn_bias"
    elif attn_type == AttentionType.ENCODER:
        target_field = "encoder_attn_bias"
    elif attn_type == AttentionType.ENCODER_DECODER:
        target_field = "cross_attn_bias"
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
    setattr(self, target_field, attn_bias)
def get_seq_len_block_table_args(
    self,
    attn_type: AttentionType,
) -> tuple:
    '''
    Pick the sequence-length and block-table attributes that match the
    type of attention operation.

    Decoder attn -> decoder self-attention fields
    (seq_lens_tensor, max_decode_seq_len, block_tables)

    Encoder/decoder cross-attn -> encoder sequence-length fields &
    cross-attn block tables

    Encoder attn -> encoder sequence-length fields & no block tables

    Arguments:

    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AttributeError if attn_type is none of the three known types
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        return (self.seq_lens_tensor, self.max_decode_seq_len,
                self.block_tables)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        tables = self.cross_block_tables
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        tables = None
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

    return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, tables)
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...@@ -151,7 +319,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -151,7 +319,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
...@@ -164,88 +332,108 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -164,88 +332,108 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert k_scale == 1.0 and v_scale == 1.0 assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER: if (attn_type == AttentionType.ENCODER
raise NotImplementedError("Encoder self-attention and " and (not attn_metadata.is_all_encoder_attn_metadata_set)):
"encoder/decoder cross-attention " raise AttributeError("Encoder attention requires setting "
"are not implemented for " "encoder metadata attributes.")
"TorchSDPABackendImpl") elif (attn_type == AttentionType.ENCODER_DECODER
num_tokens, hidden_size = query.shape and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) if key is not None:
value = value.view(-1, self.num_kv_heads, self.head_size) assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None: value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype, k_scale,
v_scale)
if attn_metadata.is_prompt: if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
k_scale, v_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if (kv_cache.numel() == 0
if self.num_kv_heads != self.num_heads: or prefill_meta.block_tables.numel() == 0):
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) output = self._run_sdpa_forward(query,
value = value.repeat_interleave(self.num_queries_per_kv, key,
dim=1) value,
prefill_meta,
if attn_metadata.attn_bias is None: attn_type=attn_type)
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
start = 0
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[None, :, start:end, :],
key[None, :, start:end, :],
value[None, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
else: else:
# prefix-enabled attention # prefix-enabled attention
raise RuntimeError( raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.") "Torch SDPA backend doesn't support prefix decoding.")
else: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = decode_meta.get_seq_len_block_table_args(attn_type)
output = PagedAttention.forward_decode( output = PagedAttention.forward_decode(
query, query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, block_tables_arg,
attn_metadata.seq_lens_tensor, seq_lens_arg,
attn_metadata.max_decode_seq_len, max_seq_len_arg,
self.kv_cache_dtype, self.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
...@@ -257,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -257,6 +445,59 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
# Reshape the output tensor. # Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size) return output.view(-1, self.num_heads * self.head_size)
def _run_sdpa_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: TorchSDPAMetadata,
    attn_type: AttentionType = AttentionType.DECODER,
):
    '''
    Run scaled-dot-product attention one sequence at a time and gather the
    per-sequence results into a tensor shaped like `query`.

    The attention bias for `attn_type` is taken from `attn_metadata`; when
    none is stored yet it is created here (alibi, sliding-window, or a list
    of None) and written back to the metadata.

    NOTE(review): query/key/value are presumably
    [num_tokens, num_(kv_)heads, head_size] - confirm against the caller.

    Arguments:

    * query, key, value: attention inputs, token axis first
    * attn_metadata: metadata providing sequence lengths and bias storage
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Attention output with the same shape and dtype as `query`
    '''
    if self.num_kv_heads != self.num_heads:
        # Repeat each KV head num_queries_per_kv times along the head axis.
        key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
        value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

    biases = attn_metadata.get_attn_bias(attn_type)
    if biases is None:
        if self.alibi_slopes is not None:
            biases = _make_alibi_bias(
                self.alibi_slopes, query.dtype,
                attn_metadata.seq_lens)  # type: ignore
        elif self.sliding_window is not None:
            assert attn_metadata.seq_lens is not None
            biases = _make_sliding_window_bias(
                attn_metadata.seq_lens, self.sliding_window,
                query.dtype)  # type: ignore
        else:
            # No bias needed: one None entry per query sequence.
            q_lens, _ = attn_metadata.get_seq_lens(attn_type)
            biases = [None] * len(q_lens)
        # Store the bias on the metadata under this attention type.
        attn_metadata.set_attn_bias(biases, attn_type)

    output = torch.empty_like(query)

    # Move the token axis to the second-to-last position of each tensor.
    token_axis = query.dim() - 2
    query = query.movedim(0, token_axis)
    key = key.movedim(0, key.dim() - 2)
    value = value.movedim(0, value.dim() - 2)

    causal_attn = (attn_type == AttentionType.DECODER)

    seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
    q_begin, kv_begin = 0, 0
    for q_len, kv_len, bias in zip(seq_lens_q, seq_lens_kv, biases):
        q_stop = q_begin + q_len
        kv_stop = kv_begin + kv_len
        attended = scaled_dot_product_attention(
            query[None, :, q_begin:q_stop, :],
            key[None, :, kv_begin:kv_stop, :],
            value[None, :, kv_begin:kv_stop, :],
            attn_mask=bias,
            dropout_p=0.0,
            is_causal=causal_attn and not self.need_mask,
            scale=self.scale)
        # Drop the temporary batch axis and move the token axis back to the
        # front before writing into the output slice.
        output[q_begin:q_stop, :, :] = attended.squeeze(0).movedim(
            token_axis, 0)
        q_begin, kv_begin = q_stop, kv_stop
    return output
def _make_alibi_bias( def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
......
...@@ -38,18 +38,12 @@ def is_block_tables_empty(block_tables: Union[None, Dict]): ...@@ -38,18 +38,12 @@ def is_block_tables_empty(block_tables: Union[None, Dict]):
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
context_len: int, sliding_window: int, context_len: int, sliding_window: int):
use_v2_block_manager: bool):
""" """
Compute the start index of slot mapping. Compute the start index of slot mapping.
""" """
start_idx = 0 start_idx = 0
if is_prompt and sliding_window is not None: if is_prompt and sliding_window is not None:
assert use_v2_block_manager or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - sliding_window) start_idx = max(0, query_len - sliding_window)
return start_idx return start_idx
...@@ -138,8 +132,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -138,8 +132,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.sliding_window = input_builder.sliding_window self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def _add_seq_group( def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
...@@ -180,9 +172,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -180,9 +172,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# Compute slot mapping. # Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables) is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx( start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
is_prompt, query_len, context_len, self.sliding_window, context_len,
self.use_v2_block_manager) self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx, seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables) self.block_size, inter_data.block_tables)
...@@ -312,7 +304,8 @@ class CommonAttentionState(AttentionState): ...@@ -312,7 +304,8 @@ class CommonAttentionState(AttentionState):
slot_mapping=self._graph_slot_mapping[:batch_size], slot_mapping=self._graph_slot_mapping[:batch_size],
seq_lens=None, seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size], seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=None, max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture, max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None, query_start_loc=None,
......
...@@ -118,6 +118,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -118,6 +118,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum query length in the batch. None for decoding. # Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None max_query_len: Optional[int] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in # (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length # the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10]. # is [4, 6], it is [0, 4, 10].
...@@ -445,7 +448,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -445,7 +448,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor], key: Optional[torch.Tensor],
value: Optional[torch.Tensor], value: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor], kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
k_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0, v_scale: float = 1.0,
...@@ -489,6 +492,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -489,6 +492,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention, attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross- decoder self-attention, or encoder/decoder cross-
...@@ -522,7 +527,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -522,7 +527,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# which KV cache memory-mapping & which # which KV cache memory-mapping & which
# seqlen datastructures we utilize # seqlen datastructures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache is not None): if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or # KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not # encoder-decoder-cross-attention, but not
# during encoder attention. # during encoder attention.
...@@ -554,25 +559,32 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -554,25 +559,32 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self.kv_cache_dtype, self.kv_cache_dtype,
k_scale, v_scale) k_scale, v_scale)
if attn_type != AttentionType.ENCODER: if attn_type == AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable; # Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape and treat them # derive token-count from query shape and treat them
# as 100% prefill tokens # as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens num_prefill_tokens = attn_metadata.num_encoder_tokens
num_encoder_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0 num_decode_tokens = 0
elif attn_type == AttentionType.DECODER:
if attn_type == AttentionType.DECODER: # Decoder self-attention supports chunked prefill.
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
# Only enforce this shape-constraint for decoder # Only enforce this shape-constraint for decoder
# self-attention # self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens
else: # attn_type == AttentionType.ENCODER_DECODER
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
if attn_metadata.num_encoder_tokens is not None:
num_encoder_tokens = attn_metadata.num_encoder_tokens
else:
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
output = torch.empty_like(query) output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
...@@ -580,15 +592,15 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -580,15 +592,15 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# QKV for prefill. # QKV for prefill.
query = query[:num_prefill_tokens] query = query[:num_prefill_tokens]
if key is not None and value is not None: if key is not None and value is not None:
key = key[:num_prefill_tokens] key = key[:num_encoder_tokens]
value = value[:num_prefill_tokens] value = value[:num_encoder_tokens]
assert query.shape[0] == num_prefill_tokens assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if kv_cache is None or prefill_meta.block_tables.numel() == 0: if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# normal attention. # normal attention.
# block tables are empty if the prompt does not have a cached # block tables are empty if the prompt does not have a cached
# prefix. # prefix.
......
...@@ -42,10 +42,12 @@ class Attention(nn.Module): ...@@ -42,10 +42,12 @@ class Attention(nn.Module):
kv_cache_dtype = cache_config.cache_dtype kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size block_size = cache_config.block_size
sliding_window = cache_config.sliding_window sliding_window = cache_config.sliding_window
is_attention_free = cache_config.is_attention_free
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16 block_size = 16
sliding_window = None sliding_window = None
is_attention_free = False
if num_kv_heads is None: if num_kv_heads is None:
num_kv_heads = num_heads num_kv_heads = num_heads
...@@ -76,9 +78,9 @@ class Attention(nn.Module): ...@@ -76,9 +78,9 @@ class Attention(nn.Module):
# During model initialization, the default dtype is set as the model # During model initialization, the default dtype is set as the model
# weight and activation dtype. # weight and activation dtype.
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, attn_backend = get_attn_backend(head_size, sliding_window, dtype,
sliding_window, dtype, kv_cache_dtype, kv_cache_dtype, block_size,
block_size, blocksparse_params is_attention_free, blocksparse_params
is not None) is not None)
impl_cls = attn_backend.get_impl_cls() impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
...@@ -90,7 +92,7 @@ class Attention(nn.Module): ...@@ -90,7 +92,7 @@ class Attention(nn.Module):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -24,6 +24,7 @@ class _Backend(enum.Enum): ...@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
FLASHINFER = enum.auto() FLASHINFER = enum.auto()
PALLAS = enum.auto() PALLAS = enum.auto()
IPEX = enum.auto() IPEX = enum.auto()
NO_ATTENTION = enum.auto()
def backend_name_to_enum(backend_name: str) -> _Backend: def backend_name_to_enum(backend_name: str) -> _Backend:
...@@ -88,13 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: ...@@ -88,13 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_attn_backend( def get_attn_backend(
num_heads: int,
head_size: int, head_size: int,
num_kv_heads: int,
sliding_window: Optional[int], sliding_window: Optional[int],
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: Optional[str], kv_cache_dtype: Optional[str],
block_size: int, block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False, is_blocksparse: bool = False,
) -> Type[AttentionBackend]: ) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
...@@ -105,9 +105,8 @@ def get_attn_backend( ...@@ -105,9 +105,8 @@ def get_attn_backend(
BlocksparseFlashAttentionBackend) BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend return BlocksparseFlashAttentionBackend
backend = which_attn_to_use(num_heads, head_size, num_kv_heads, backend = which_attn_to_use(head_size, sliding_window, dtype,
sliding_window, dtype, kv_cache_dtype, kv_cache_dtype, block_size, is_attention_free)
block_size)
if backend == _Backend.FLASH_ATTN: if backend == _Backend.FLASH_ATTN:
from vllm.attention.backends.flash_attn import ( # noqa: F401 from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend) FlashAttentionBackend)
...@@ -146,23 +145,31 @@ def get_attn_backend( ...@@ -146,23 +145,31 @@ def get_attn_backend(
logger.info("Using Pallas backend.") logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend return PallasAttentionBackend
elif backend == _Backend.NO_ATTENTION:
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
return PlaceholderAttentionBackend
else: else:
raise ValueError("Invalid attention backend.") raise ValueError("Invalid attention backend.")
def which_attn_to_use( def which_attn_to_use(
num_heads: int,
head_size: int, head_size: int,
num_kv_heads: int,
sliding_window: Optional[int], sliding_window: Optional[int],
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: Optional[str], kv_cache_dtype: Optional[str],
block_size: int, block_size: int,
is_attention_free: bool,
) -> _Backend: ) -> _Backend:
"""Returns which flash attention backend to use.""" """Returns which flash attention backend to use."""
# Default case. # Default case.
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
return _Backend.NO_ATTENTION
# Check whether a particular choice of backend was # Check whether a particular choice of backend was
# previously forced. # previously forced.
# #
......
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Holds the token ids of the candidate together with its accumulated
    log probability. ``text`` is optional; it is only populated when the
    sequence is about to be handed back to the user.
    """

    # Token ids of the sequence; the prompt tokens are included.
    tokens: List[int]
    # Accumulated log probability of the sequence.
    cum_logprob: float = 0.0
    # Text of the sequence; left as None until it is filled in.
    text: Optional[str] = None
@dataclass
class BeamSearchOutput:
    """Result returned by a beam search run.

    Wraps the list of the best beam search sequences; the list has as
    many entries as the beam width.
    """

    # Best sequences found by the search, one per beam.
    sequences: List[BeamSearchSequence]
class BeamSearchInstance:
    """Per-prompt beam search state: the live beams and the finished ones."""

    def __init__(self, prompt_tokens: List[int]):
        # The search starts from a single beam holding just the prompt.
        # The prompt list is stored as given (it is not copied here).
        initial_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [initial_beam]
        # Starts empty; nothing in this class appends to it.
        self.completed: List[BeamSearchSequence] = []
def get_beam_search_score(
    tokens: List[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    The score is ``cumulative_logprob / (seq_len ** length_penalty)``,
    where ``seq_len`` is the number of tokens, not counting a trailing
    EOS token.

    Adapted from
    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

    Args:
        tokens: Token ids of the sequence.
        cumulative_logprob: Accumulated log probability of the sequence.
        eos_token_id: Id of the end-of-sequence token.
        length_penalty: Exponent applied to the sequence length.

    Returns:
        The length-normalized score of the sequence.
    """
    seq_len = len(tokens)
    # A trailing EOS token does not count towards the length. The
    # emptiness check avoids an IndexError on ``tokens[-1]``.
    if tokens and tokens[-1] == eos_token_id:
        seq_len -= 1
    # Clamp to 1 so that an empty sequence, or one made up of a single
    # EOS token, is scored without normalization instead of dividing by
    # zero.
    seq_len = max(seq_len, 1)
    return cumulative_logprob / (seq_len**length_penalty)
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    """Build a key function that scores beams with a fixed EOS id and penalty.

    The returned callable takes a ``BeamSearchSequence`` and returns its
    ``get_beam_search_score`` value, computed from the beam's tokens and
    cumulative log probability.
    """

    def _score_of(beam: BeamSearchSequence) -> float:
        return get_beam_search_score(
            beam.tokens,
            beam.cum_logprob,
            eos_token_id,
            length_penalty,
        )

    return _score_of
...@@ -12,11 +12,11 @@ from tqdm import tqdm ...@@ -12,11 +12,11 @@ from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase) PreTrainedTokenizerBase)
from vllm.inputs import PromptInputs
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args) build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators from vllm.utils import FlexibleArgumentParser, merge_async_iterators
...@@ -75,7 +75,6 @@ def run_vllm( ...@@ -75,7 +75,6 @@ def run_vllm(
tensor_parallel_size: int, tensor_parallel_size: int,
seed: int, seed: int,
n: int, n: int,
use_beam_search: bool,
trust_remote_code: bool, trust_remote_code: bool,
dtype: str, dtype: str,
max_model_len: Optional[int], max_model_len: Optional[int],
...@@ -89,11 +88,9 @@ def run_vllm( ...@@ -89,11 +88,9 @@ def run_vllm(
distributed_executor_backend: Optional[str], distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1, num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format, load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
use_new_beam_search_impl: bool = False,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM( llm = LLM(
...@@ -117,7 +114,6 @@ def run_vllm( ...@@ -117,7 +114,6 @@ def run_vllm(
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
load_format=load_format, load_format=load_format,
num_scheduler_steps=num_scheduler_steps, num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
) )
...@@ -129,13 +125,12 @@ def run_vllm( ...@@ -129,13 +125,12 @@ def run_vllm(
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
temperature=0.0 if use_beam_search else 1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
)) ))
# warmup # warmup
warmup_prompts = [] warmup_prompts = []
warmup_sampling_params = [] warmup_sampling_params = []
...@@ -144,9 +139,8 @@ def run_vllm( ...@@ -144,9 +139,8 @@ def run_vllm(
warmup_sampling_params.append( warmup_sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
temperature=0.0 if use_beam_search else 1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
)) ))
...@@ -158,7 +152,7 @@ def run_vllm( ...@@ -158,7 +152,7 @@ def run_vllm(
# dummy_prompt_token_ids = np.random.randint(10000, # dummy_prompt_token_ids = np.random.randint(10000,
# size=(args.num_prompts, # size=(args.num_prompts,
# args.input_len)) # args.input_len))
# dummy_inputs: List[PromptInputs] = [{ # dummy_prompts: List[PromptType] = [{
# "prompt_token_ids": batch # "prompt_token_ids": batch
# } for batch in dummy_prompt_token_ids.tolist()] # } for batch in dummy_prompt_token_ids.tolist()]
...@@ -171,22 +165,27 @@ def run_vllm( ...@@ -171,22 +165,27 @@ def run_vllm(
# for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): # for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
# run_to_completion() # run_to_completion()
if not use_new_beam_search_impl:
use_beam_search = False
if not use_beam_search:
start = time.perf_counter() start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True) llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
else: else:
assert use_beam_search
prompts = [prompt for prompt, _, _ in requests] prompts = [prompt for prompt, _, _ in requests]
# output_len should be the same for all requests. # output_len should be the same for all requests.
output_len = requests[0][2] output_len = requests[0][2]
for prompt, input_len, _output_len in requests: for prompt, input_len, _output_len in requests:
assert _output_len == output_len assert _output_len == output_len
start = time.perf_counter() start = time.perf_counter()
llm.beam_search(prompts, llm.beam_search(
beam_width=n, prompts,
max_tokens=output_len, BeamSearchParams(
ignore_eos=True) beam_width=n,
max_tokens=output_len,
ignore_eos=True,
))
end = time.perf_counter() end = time.perf_counter()
return end - start return end - start
...@@ -199,7 +198,6 @@ async def run_vllm_async( ...@@ -199,7 +198,6 @@ async def run_vllm_async(
tensor_parallel_size: int, tensor_parallel_size: int,
seed: int, seed: int,
n: int, n: int,
use_beam_search: bool,
trust_remote_code: bool, trust_remote_code: bool,
dtype: str, dtype: str,
max_model_len: Optional[int], max_model_len: Optional[int],
...@@ -213,7 +211,6 @@ async def run_vllm_async( ...@@ -213,7 +211,6 @@ async def run_vllm_async(
distributed_executor_backend: Optional[str], distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1, num_scheduler_steps: int = 1,
use_v2_block_manager: bool = False,
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format, load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
...@@ -241,7 +238,6 @@ async def run_vllm_async( ...@@ -241,7 +238,6 @@ async def run_vllm_async(
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
load_format=load_format, load_format=load_format,
num_scheduler_steps=num_scheduler_steps, num_scheduler_steps=num_scheduler_steps,
use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False, worker_use_ray=False,
disable_log_requests=True, disable_log_requests=True,
...@@ -258,9 +254,8 @@ async def run_vllm_async( ...@@ -258,9 +254,8 @@ async def run_vllm_async(
sampling_params.append( sampling_params.append(
SamplingParams( SamplingParams(
n=n, n=n,
temperature=0.0 if use_beam_search else 1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True, ignore_eos=True,
max_tokens=output_len, max_tokens=output_len,
)) ))
...@@ -282,11 +277,9 @@ def run_hf( ...@@ -282,11 +277,9 @@ def run_hf(
model: str, model: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
n: int, n: int,
use_beam_search: bool,
max_batch_size: int, max_batch_size: int,
trust_remote_code: bool, trust_remote_code: bool,
) -> float: ) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained( llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama": if llm.config.model_type == "llama":
...@@ -318,7 +311,7 @@ def run_hf( ...@@ -318,7 +311,7 @@ def run_hf(
padding=True).input_ids padding=True).input_ids
llm_outputs = llm.generate( llm_outputs = llm.generate(
input_ids=input_ids.cuda(), input_ids=input_ids.cuda(),
do_sample=not use_beam_search, do_sample=True,
num_return_sequences=n, num_return_sequences=n,
temperature=1.0, temperature=1.0,
top_p=1.0, top_p=1.0,
...@@ -378,40 +371,37 @@ def main(args: argparse.Namespace): ...@@ -378,40 +371,37 @@ def main(args: argparse.Namespace):
if args.async_engine: if args.async_engine:
run_args = [ run_args = [
requests, args.model, args.tokenizer, args.quantization, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.tensor_parallel_size, args.seed, args.n,
args.trust_remote_code, args.dtype, args.max_model_len, args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device, args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill, args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend, args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps, args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format, args.download_dir, args.load_format, args.disable_async_output_proc
args.disable_async_output_proc
] ]
else: else:
run_args = [ run_args = [
warmup_requests, requests, args.model, args.tokenizer, args.quantization, warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.tensor_parallel_size, args.seed, args.n,
args.trust_remote_code, args.dtype, args.max_model_len, args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device, args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill, args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend, args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps, args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format, args.download_dir, args.load_format, args.disable_async_output_proc
args.disable_async_output_proc
] ]
if args.async_engine: if args.async_engine:
run_args.append(args.disable_frontend_multiprocessing) run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args)) elapsed_time = uvloop.run(run_vllm_async(*run_args))
else: else:
elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl) elapsed_time = run_vllm(*run_args)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size, args.hf_max_batch_size, args.trust_remote_code)
args.trust_remote_code)
elif args.backend == "mii": elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len) args.output_len)
...@@ -473,12 +463,10 @@ if __name__ == "__main__": ...@@ -473,12 +463,10 @@ if __name__ == "__main__":
type=int, type=int,
default=1, default=1,
help="Number of generated sequences per prompt.") help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument('--num-iters-warmup', parser.add_argument('--num-iters-warmup',
type=int, type=int,
default=1, default=1,
help='Number of iterations to run for warmup.') help='Number of iterations to run for warmup.')
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts", parser.add_argument("--num-prompts",
type=int, type=int,
default=1000, default=1000,
...@@ -543,9 +531,6 @@ if __name__ == "__main__": ...@@ -543,9 +531,6 @@ if __name__ == "__main__":
type=int, type=int,
default=1, default=1,
help="Maximum number of forward steps per scheduler call.") help="Maximum number of forward steps per scheduler call.")
parser.add_argument("--use-v2-block-manager",
action='store_true',
help="Enable block manager v2.")
parser.add_argument( parser.add_argument(
"--enable-prefix-caching", "--enable-prefix-caching",
action='store_true', action='store_true',
...@@ -633,8 +618,6 @@ if __name__ == "__main__": ...@@ -633,8 +618,6 @@ if __name__ == "__main__":
raise ValueError("dtype must be auto for MII backend.") raise ValueError("dtype must be auto for MII backend.")
if args.n != 1: if args.n != 1:
raise ValueError("n must be 1 for MII backend.") raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None: if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.") raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None: if args.hf_max_batch_size is not None:
......
__commit__ = "93ec62b8556e279d2c050bdc1c3247831bd39466"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment