Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
#!/bin/bash
# Script to install DeepGEMM from source
# This script can be used both in Docker builds and by users locally
set -e
# Default values
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--ref)
if [[ -z "$2" || "$2" =~ ^- ]]; then
echo "Error: --ref requires an argument." >&2
exit 1
fi
DEEPGEMM_GIT_REF="$2"
shift 2
;;
--cuda-version)
if [[ -z "$2" || "$2" =~ ^- ]]; then
echo "Error: --cuda-version requires an argument." >&2
exit 1
fi
CUDA_VERSION="$2"
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --ref REF Git reference to checkout (default: $DEEPGEMM_GIT_REF)"
echo " --cuda-version VER CUDA version (auto-detected if not provided)"
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1" >&2
exit 1
;;
esac
done
# Auto-detect CUDA version if not provided
if [ -z "$CUDA_VERSION" ]; then
if command -v nvcc >/dev/null 2>&1; then
CUDA_VERSION=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]\+\.[0-9]\+\).*/\1/p')
echo "Auto-detected CUDA version: $CUDA_VERSION"
else
echo "Warning: Could not auto-detect CUDA version. Please specify with --cuda-version"
exit 1
fi
fi
# Extract major and minor version numbers
CUDA_MAJOR="${CUDA_VERSION%%.*}"
CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
CUDA_MINOR="${CUDA_MINOR%%.*}"
echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)"
# Check CUDA version requirement
if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then
echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
exit 0
fi
echo "Installing DeepGEMM from source..."
echo "Repository: $DEEPGEMM_GIT_REPO"
echo "Reference: $DEEPGEMM_GIT_REF"
# Create a temporary directory for the build
INSTALL_DIR=$(mktemp -d)
trap 'rm -rf "$INSTALL_DIR"' EXIT
# Clone the repository
git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm"
echo "🏗️ Building DeepGEMM"
pushd "$INSTALL_DIR/deepgemm"
# Checkout the specific reference
git checkout "$DEEPGEMM_GIT_REF"
# Build DeepGEMM
# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
rm -rf build dist
rm -rf *.egg-info
python3 setup.py bdist_wheel
# Install the wheel
if command -v uv >/dev/null 2>&1; then
echo "Installing DeepGEMM wheel using uv..."
# Use --system in Docker contexts, respect user's environment otherwise
if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then
uv pip install --system dist/*.whl
else
uv pip install dist/*.whl
fi
else
echo "Installing DeepGEMM wheel using pip..."
python3 -m pip install dist/*.whl
fi
popd
echo "✅ DeepGEMM installation completed successfully"
\ No newline at end of file
......@@ -36,8 +36,7 @@ profiling and analyzing nsys profile output.
## Notes
- Make sure you have pandas installed.
- Make sure nsys is installed, and specify the path to the `nsys` command with
`--nsys_cmd` if it is not in your PATH.
- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is installed, and specify the path to the `nsys` command with `--nsys_cmd` if it is not in your PATH.
- For more details on available engines and models, see the help string in
the script or run:
......@@ -135,34 +134,31 @@ time which would cause a difference for the overall category.
## Example 3: add new classification for a new model
Suppose there's a new model ABC that is available for engine DEF, and say there
are 4 kernels to be classified into "gemm" and "attn", where the gemm kernels
To create a new engine DEF with model ABC, just add another json file in the same directory as
gputrc2graph.py with the same format as the other json files. The script will automatically pick up all the json files in the same directory as engine/model specifications.
Then, for this new model, suppose there are 4 kernels to be classified into "gemm" and "attn", where the gemm kernels
have names with "*H*" or "*I*" in them, and attn kernels have names with "*J*"
or "*K*" in them, add a new entry like so:
```python
engine_model = {
'DEF': {
'ABC': {
'layer_anno': {
'Stage': {
'.*': 'layer',
},
'Substage': {
'H|I': 'gemm',
'J|K': 'attn',
'CUDA mem': 'non-gpu-H_D_memops',
'.*': 'misc'
}
}
},
}
'vllm': {...}
or "*K*" in them, just add another .json file in the same directory as
gputrc2graph.py with the same format as the other json files, like the following:
```json
{
"DEF": {
"ABC": {
"H|I": "gemm",
"J|K": "attn",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
}
}
}
```
Basically Substage is a dictionary with a list of key/value pairs, where the
keys are regex's of the kernel names to be classified, and values are the
classification bins which one wishes to compare across engines/models.
Each entry in the dictionary consists of:
- key: a regex used to classify the kernels
- value: the category to classify the kernels into.
The last 2 entries are common for all engine/models, consisting of CUDA memory
operations and a 'misc' for anything that's leftover and can't be classified.
......@@ -173,3 +169,6 @@ like the following:
```bash
--infile new.nsys-rep,DEF,ABC,<runtime>
```
If the engine_DEF.json file already exists, just add the model as a new node in
the existing engine file, after the other models.
......@@ -15,132 +15,18 @@ logger = logging.getLogger(__name__)
# helper data class for annotating kernels
class EngineModelData:
# engine + model mappings
engine_model = {
'vllm': {
'llama': {
'layer_anno': {
'Stage': {
'.*': 'layer',
},
'Substage': {
'gemm': 'gemm',
'fused_moe_kernel|GroupProblemShape|group_gemm_starts':
'moe_gemm', #llama4
'moe|sigmoid': 'moe', #llama4
'CatArrayBatched|prepare_inputs': 'prepare_next',
'flash': 'attn',
'ncclDevKernel|cross_device_reduce':
'nccl_and_custom_ar',
'_norm_': 'norm',
'act_and_mul_': 'silu',
'rotary_embedding_kernel': 'rope',
'SoftMax': 'softmax',
'elementwise': 'elementwise',
'fp8_quant': 'quantize',
'reduce_kernel': 'reduce',
'triton': 'triton_kernel',
'CUDA mem': 'non-gpu-H_D_memops',
'.*': 'misc'
}
}
},
'ds': {
'layer_anno': {
'Stage': {
'.*': 'layer',
},
'Substage': {
'block_fp8|gemm_fp8_blockwise':
'block_fp8_gemm',
'fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal':
'moe_gemm',
'gemm|matmul|nvjet':
'gemm',
'moe|sigmoid|expert':
'moe',
'_fwd_|FlashAttn|_mla_|_attn_':
'attn',
'CatArrayBatched':
'prepare_next',
'ncclDevKernel|cross_device_reduce':
'nccl_and_custom_ar',
'Norm|_norm_':
'norm',
'sbtopk':
'topk',
'act_and_mul_':
'activation',
'compute_position_kernel':
'rope',
'elementwise':
'elementwise',
'fp8_quant|quant_fp8|cvt_fp16_to_fp4':
'quantize',
'reduce':
'reduce',
'SoftMax':
'softmax',
'triton':
'triton_kernel',
'CUDA mem':
'non-gpu-H_D_memops',
'.*':
'misc'
}
}
},
'gpt-oss': {
'layer_anno': {
'Stage': {
'.*': 'layer',
},
'Substage': {
'block_fp8|gemm_fp8_blockwise':
'block_fp8_gemm',
'fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_'
# this section is triton_moe_gemm
'|matmul_ogs_|_topk_forward|_combined_routing'
'|_sum_bitmatrix_rows|_compute_writeback_idx':
'moe_gemm',
'gemm|matmul|nvjet':
'gemm',
'moe|sigmoid|expert|splitKreduce':
'moe',
'_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha':
'attn',
'CatArrayBatched':
'prepare_next',
'ncclDevKernel|cross_device_reduce':
'nccl_and_custom_ar',
'Norm|_norm_':
'norm',
'sbtopk':
'topk',
'act_and_mul_':
'activation',
'compute_position_kernel':
'rope',
'elementwise':
'elementwise',
'fp8_quant|quant_fp8|cvt_fp16_to_fp4|quantize':
'quantize',
'reduce':
'reduce',
'SoftMax':
'softmax',
'triton':
'triton_kernel',
'CUDA mem':
'non-gpu-H_D_memops',
'.*':
'misc'
}
}
}
},
}
def load_engine_model():
""" returns engine_model built from all json files in the current dir """
import glob
import json
engine_model = {}
json_files = glob.glob(
os.path.join(os.path.dirname(__file__) or ".", "*.json"))
for fname in json_files:
with open(fname, encoding="utf-8") as f:
engine_model.update(json.load(f))
return engine_model
class GPUTrace2Graph:
......@@ -148,8 +34,7 @@ class GPUTrace2Graph:
Parses output of nsys report, generates csv and bar chart output
"""
def __init__(self, nsys_cmd):
self.nsys_cmd = nsys_cmd
def __init__(self):
import pandas as pd # avoid importing till needed
self.pd = pd
self.pd.options.mode.copy_on_write = True
......@@ -227,7 +112,7 @@ class GPUTrace2Graph:
title = 'Model_Engine'
x = 'Model_Engine'
y = 'Elapsed Time (sec)'
color = 'Substage'
color = 'Category'
""" generate kernel mapping table """
# Sort Model_Engine categories by last field after underscore
df['Model_Engine'] = self.pd.Categorical(
......@@ -249,14 +134,13 @@ class GPUTrace2Graph:
Generate data table with columns per Model_Engine into result.html
"""
pivot_df = df.pivot_table(values='Elapsed Time (sec)',
index='Substage',
index='Category',
columns='Model_Engine',
aggfunc='sum',
observed=False).round(2)
# Add sum row at bottom
pivot_df.loc['total_elapsed_sec'] = pivot_df.sum()
pivot_df.fillna('').to_html('temp.html')
print('got')
with (open(f'{output_name}.html', 'a', encoding='utf-8') as
outfile, open('temp.html', encoding='utf-8') as infile):
outfile.write(infile.read())
......@@ -264,23 +148,22 @@ class GPUTrace2Graph:
print(f'Finished generating: \n'
f' {output_name}.html for stack bar chart \n'
f' {output_name}.csv for Kernel-Substage mapping')
f' {output_name}.csv for Kernel-Category mapping')
def anno_gpu_kernname(self, df, mapping):
""" add "stage" and "substage" columns """
""" add "Category" column """
def anno_gpu_kernname_helper(name, stage):
for kern_name, val in mapping['layer_anno'][stage].items():
def anno_gpu_kernname_helper(name):
for kern_name, val in mapping.items():
if re.search(kern_name, name):
return val
for stage in ['Stage', 'Substage']:
df[stage] = df['Name'].apply(anno_gpu_kernname_helper, stage=stage)
df['Category'] = df['Name'].apply(anno_gpu_kernname_helper)
def make_nongpu_row(self, df, nongpu_sec):
""" this will append non-gpu time entry at end of df """
nongpu_row = self.pd.DataFrame([df.iloc[-1]])
nongpu_row['Substage'] = nongpu_row['Name'] = 'CPU(non-GPU)'
nongpu_row['Category'] = nongpu_row['Name'] = 'CPU(non-GPU)'
nongpu_row['Instances'] = 1
nongpu_row['Elapsed Time (sec)'] = nongpu_sec
return (nongpu_row)
......@@ -302,7 +185,7 @@ class GPUTrace2Graph:
logger.info('generating %s', new_file)
return True
def gen_sum_file(self, file):
def gen_sum_file(self, file, nsys_cmd):
"""
generates sum file from nsys trace with times per kernel and
returns the name of the sum file
......@@ -318,17 +201,21 @@ class GPUTrace2Graph:
sum_file = f'{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv'
if self.should_gen_file(nsys_stats_file, file):
cmd = [
self.nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o',
nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o',
f'{file_dir}/{file_name}'
]
cmd_str = ' '.join(cmd)
logger.info('+ %s', cmd_str)
# estimate time based on calibrated 240M/min
file_size_mb = os.path.getsize(file) / 1e6
logger.info(
'nsys stats for %.2f MB file expected to take %.2f min',
file_size_mb, file_size_mb / 240)
try:
subprocess.run(cmd)
subprocess.run(cmd, check=True)
except Exception:
logger.error(
"%s failed, specify --nsys_cmd for correct nsys path",
cmd_str)
logger.error("%s failed; Use --nsys_cmd to specify nsys path",
cmd_str)
exit(1)
logger.info('generating non-overalapped sum %s', sum_file)
self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file)
......@@ -336,7 +223,7 @@ class GPUTrace2Graph:
logger.info('Finished generating %s', sum_file)
return sum_file
def gen_graph(self, in_file, out_dir, title):
def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model):
""" generates graph and csv file from in_file into out_dir """
# Initialize an empty DataFrame to store combined data
combined_df = self.pd.DataFrame()
......@@ -345,17 +232,16 @@ class GPUTrace2Graph:
file_name = os.path.basename(file)
if not file_dir:
file_dir = '.'
sum_file = self.gen_sum_file(file)
sum_file = self.gen_sum_file(file, nsys_cmd)
# read kernel summary file
df = self.pd.read_csv(sum_file)
# annotate kernel to their categories
assert EngineModelData.engine_model.get(engine)
assert EngineModelData.engine_model[engine].get(model)
assert engine_model.get(engine), f'engine {engine} unknown'
assert engine_model[engine].get(model), f'model {model} unknown'
# remove nsys-rep from file_name for shorter x-label
file_name = file_name.replace('.nsys-rep', '')
df['Model_Engine'] = f'{model}_{engine}_{file_name}_{idx}'
self.anno_gpu_kernname(df,
EngineModelData.engine_model[engine][model])
self.anno_gpu_kernname(df, engine_model[engine][model])
# patch in non-gpu time
gpu_sec = round(df['Elapsed Time (sec)'].sum(), 1)
total_sec = round(float(total_sec), 1)
......@@ -393,12 +279,12 @@ def main():
"--out_dir results/ --title \"Model=gpt-oss vLLM chart\""),
formatter_class=argparse.RawDescriptionHelpFormatter)
# Build help string showing available engine/model combinations
engine_model_help = []
for engine, models in EngineModelData.engine_model.items():
model_list = list(models.keys())
engine_model_help.append(f"{engine}:[{','.join(model_list)}]")
engine_model_str = ' '.join(engine_model_help)
# load supported engine_model
engine_model_supported = load_engine_model()
# Get a string representation of supported engine/model combinations
engine_model_supported_str = ', '.join(
f"{engine}:[{', '.join(models.keys())}]"
for engine, models in engine_model_supported.items())
parser.add_argument(
'--in_file',
type=parse_tuple,
......@@ -408,7 +294,7 @@ def main():
'separated by space. Elapsed_nonprofiled_sec is runtime without '
'profiling used to calculate non-gpu time. Specify 0 to use '
'elapsed time from nsys-rep but that might inflate non-gpu time. '
f'Available engine:[model] are: {engine_model_str} '
f'Available engine:[model] are: {engine_model_supported_str} '
f'Example: --infile d1.nsys-rep,vllm,llama,100 '
'd2.nsys-rep,vllm,gpt-oss,102'),
required=True)
......@@ -418,8 +304,9 @@ def main():
help=('nsys cmd, e.g. /usr/bin/nsys, Default: nsys'),
default="nsys")
args = parser.parse_args()
gputrace = GPUTrace2Graph(args.nsys_cmd)
gputrace.gen_graph(args.in_file, args.out_dir, args.title)
gputrace = GPUTrace2Graph()
gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd,
engine_model_supported)
if __name__ == '__main__':
......
{
"vllm": {
"llama": {
"fused_moe_kernel|GroupProblemShape|group_gemm_starts|bmm_|GemmUniversal": "moe_gemm",
"gemm|nvjet": "gemm",
"moe|sigmoid": "moe",
"CatArrayBatched|prepare_inputs": "prepare_next",
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
"_norm_|Norm": "norm",
"act_and_mul_": "activation",
"Rotary": "rope",
"SoftMax": "softmax",
"flash|fmha": "attn",
"elementwise": "elementwise",
"fp8_quant|cvt_": "quantize",
"reduce_kernel": "reduce",
"triton": "triton_kernel",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
},
"ds": {
"block_fp8|gemm_fp8_blockwise": "block_fp8_gemm",
"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_": "moe_gemm",
"gemm|matmul|nvjet": "gemm",
"moe|sigmoid|expert": "moe",
"CatArrayBatched": "prepare_next",
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
"Norm|_norm_": "norm",
"sbtopk": "topk",
"act_and_mul_": "activation",
"compute_position_kernel": "rope",
"elementwise": "elementwise",
"fp8_quant|quant_fp8|cvt_": "quantize",
"reduce": "reduce",
"SoftMax": "softmax",
"_fwd_|FlashAttn|_mla_|_attn_|fmha": "attn",
"triton": "triton_kernel",
"topk": "topk",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
},
"gpt-oss": {
"block_fp8|gemm_fp8_blockwise": "block_fp8_gemm",
"fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm",
"gemm|matmul|nvjet": "gemm",
"moe|sigmoid|expert|splitKreduce": "moe",
"CatArrayBatched": "prepare_next",
"ncclDevKernel|cross_device_reduce": "nccl_and_custom_ar",
"Norm|_norm_": "norm",
"topk": "topk",
"act_and_mul_": "activation",
"compute_position_kernel": "rope",
"elementwise": "elementwise",
"fp8_quant|quant_fp8|cvt_|quantize": "quantize",
"reduce": "reduce",
"SoftMax": "softmax",
"_fwd_|FlashAttn|_mla_|_attn_|_flash_|flash::prepare_varlen|fmha": "attn",
"triton": "triton_kernel",
"CUDA mem": "non-gpu-H_D_memops",
".*": "misc"
}
}
}
\ No newline at end of file
......@@ -392,14 +392,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k)
# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
......@@ -442,25 +434,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
size_m: torch.SymInt, size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: torch.SymInt, size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: torch.SymInt,
......@@ -506,6 +479,30 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)
@register_fake("_C::cutlass_w4a8_mm")
def cutlass_w4a8_mm_fake(
a: torch.Tensor,
# b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
b_q: torch.Tensor,
b_group_scales: torch.Tensor,
b_group_size: int,
b_channel_scales: torch.Tensor,
a_token_scales: torch.Tensor,
out_type: Optional[torch.dtype] = None,
maybe_schedule: Optional[str] = None) -> torch.Tensor:
m = a.size(0)
n = b_q.size(1)
out_dtype = out_type if out_type is not None else torch.bfloat16
return torch.empty((m, n), device=a.device, dtype=out_dtype)
@register_fake("_C::cutlass_pack_scale_fp8")
def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor:
return torch.empty_like(scales, memory_format=torch.contiguous_format)
@register_fake("_C::cutlass_encode_and_reorder_int4b")
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
return torch.empty_like(b, memory_format=torch.contiguous_format)
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
......@@ -849,6 +846,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
blockscale_offsets)
def get_cutlass_moe_mm_problem_sizes(
topk_ids: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
num_experts: int,
n: int,
k: int,
blockscale_offsets: Optional[torch.Tensor] = None):
"""
Compute only the per-expert problem sizes needed by the two grouped matrix
multiplications used in CUTLASS-based fused MoE.
The function takes in topk_ids (token→expert mapping) and computes:
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs
used in the fused MoE operation.
"""
return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k,
blockscale_offsets)
def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
"""
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
......@@ -1042,6 +1061,30 @@ def machete_prepack_B(
group_scales_type)
# CUTLASS W4A8
def cutlass_w4a8_mm(
a: torch.Tensor,
# b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
b_q: torch.Tensor,
b_group_scales: torch.Tensor,
b_group_size: int,
b_channel_scales: torch.Tensor,
a_token_scales: torch.Tensor,
out_type: Optional[torch.dtype] = None,
maybe_schedule: Optional[str] = None) -> torch.Tensor:
return torch.ops._C.cutlass_w4a8_mm(a, b_q, b_group_scales, b_group_size,
b_channel_scales, a_token_scales,
out_type, maybe_schedule)
def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor:
return torch.ops._C.cutlass_pack_scale_fp8(scales)
def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
return torch.ops._C.cutlass_encode_and_reorder_int4b(b)
if hasattr(torch.ops._C, "permute_cols"):
@register_fake("_C::permute_cols")
......@@ -1331,15 +1374,6 @@ def scaled_int8_quant(
return output, input_scales, input_azp
# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int, size_k: int) -> torch.Tensor:
return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
workspace, size_m, size_n, size_k)
# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int,
dtype: Optional[torch.dtype]) -> torch.Tensor:
......@@ -1473,6 +1507,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
gating_output)
def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor,
num_expert_group: int, topk_group: int, topk: int,
renormalize: bool, routed_scaling_factor: float):
if not current_platform.is_cuda():
raise NotImplementedError("The fused grouped_topk kernel is only "
"available on CUDA platforms")
return torch.ops._moe_C.grouped_topk(scores, scores_with_bias,
num_expert_group, topk_group, topk,
renormalize, routed_scaling_factor)
def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor],
b_qweight: torch.Tensor,
b_bias: Optional[torch.Tensor],
......@@ -1585,6 +1630,20 @@ def concat_and_cache_mla(
scale)
def cp_fused_concat_and_cache_mla(
kv_c: torch.Tensor,
k_pe: torch.Tensor,
cp_local_token_select_indices: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla(
kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping,
kv_cache_dtype, scale)
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:
......@@ -1608,14 +1667,28 @@ def convert_fp8(output: torch.Tensor,
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
def gather_cache(src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
def gather_and_maybe_dequant_cache(
src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
kv_cache_dtype: str,
scale: torch.Tensor,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, seq_starts)
def cp_gather_cache(src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.cp_gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
def get_device_attribute(attribute: int, device: int) -> int:
......@@ -1846,3 +1919,86 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
M = mat1.size(0)
N = mat2.size(0)
return torch.empty((M, N), dtype=out_dtype)
class CPUDNNLGEMMHandler:
def __init__(self) -> None:
self.handler: Optional[int] = None
self.n = -1
self.k = -1
def __del__(self):
if self.handler is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler)
def create_onednn_scaled_mm(
weight: torch.Tensor, # [K, N]
weight_scales: torch.Tensor,
output_type: torch.dtype,
dynamic_quant: bool,
use_azp: bool,
primitive_cache_size: int = 128,
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_scaled_mm_handler(
weight, weight_scales, output_type, dynamic_quant, use_azp,
primitive_cache_size)
return handler
def onednn_scaled_int8_quant(input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
azp: Optional[torch.Tensor] = None,
symmetric: bool = True):
"""
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
Args:
input: The input tensor to be quantized to int8.
scale: Optional scaling factor for the int8 quantization.
When not provided, we invoke dynamic-per-token quantization.
azp: Optional zero-point for the int8 quantization.
Must be provided for asymmetric quantization if `scale` is provided.
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns:
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
"""
output = torch.empty_like(input, dtype=torch.int8)
token_num = input.numel() // input.shape[-1]
input = input.view((token_num, input.shape[-1]))
if scale is not None:
# static-per-tensor quantization.
assert symmetric == (
azp
is None), "azp must only be provided for asymmetric quantization."
torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
return output, scale, azp
# dynamic-per-token quantization.
input_scales = torch.empty((token_num, 1),
device=input.device,
dtype=torch.float32)
input_azp = None if symmetric else torch.empty_like(input_scales,
dtype=torch.int32)
torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
input_azp)
return output, input_scales, input_azp
def onednn_scaled_mm(
dnnl_handler: CPUDNNLGEMMHandler,
x: torch.Tensor,
output: torch.Tensor,
input_scale: Optional[torch.Tensor],
input_zp: Optional[torch.Tensor],
input_zp_adj: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
) -> torch.Tensor:
torch.ops._C.onednn_scaled_mm(output, x, input_scale, input_zp,
input_zp_adj, bias, dnnl_handler.handler)
return output
......@@ -11,7 +11,7 @@ from .base import get_vllm_public_assets
VLM_IMAGES_DIR = "vision_model_images"
ImageAssetName = Literal["stop_sign", "cherry_blossom"]
ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato"]
@dataclass(frozen=True)
......
......@@ -14,7 +14,6 @@ __all__ = [
"AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
]
......@@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING:
......@@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
def fused_output_quant_supported(self, quant_key: QuantKey):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.
TODO(luka) merge parameters into QuantDescriptor
:param dtype: quantized dtype
:param static: static or dynamic quantization
:param group_shape: quant group shape.
:param quant_key: QuantKey object that describes the quantization op
:return: is fusion supported for this type of quantization
"""
return False
......@@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
......
......@@ -800,23 +800,33 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
attn_metadata: DifferentialFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
layer: Attention layer instance.
q: Query tensor with shape = [num_tokens, num_heads, head_size]
k: Key tensor with shape = [num_tokens, num_kv_heads, head_size]
v: Value tensor with shape = [num_tokens, num_kv_heads, head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size, num_kv_heads, head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
output: Output tensor with shape [num_tokens, num_heads, head_size]
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
NOTE: It in-place updates the output tensor.
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for DifferentialFlashAttentionImpl")
if self.lambda_full is None:
self.lambda_init = self.differential_flash_attention_config[
"lambda_init"]
......
......@@ -371,6 +371,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention.
Args:
......@@ -386,7 +387,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
"""
assert output is None, "Output tensor not supported for DualChunk"
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
......
......@@ -596,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
......@@ -604,7 +605,8 @@ class FlashAttentionImpl(AttentionImpl):
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size, num_kv_heads, head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
......@@ -615,7 +617,7 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
......@@ -849,7 +851,7 @@ class FlashAttentionImpl(AttentionImpl):
def _get_query_key_seq_metadata(
attn_metadata,
attn_metadata: FlashAttentionMetadata,
is_prompt: bool,
attn_type: str,
) -> tuple:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import (CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
trtllm_batch_decode_with_kv_cache)
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
from vllm.vllm_flash_attn import flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
# Avoid turning these types into variables during type checking
if not TYPE_CHECKING:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
trtllm_batch_decode_with_kv_cache = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
raise ImportError("FlashInfer is not installed. Please install it from "
"https://github.com/flashinfer-ai/flashinfer") from None
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
from vllm.utils.flashinfer import use_trtllm_attention
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
class FlashInferBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "FLASHINFER"
@staticmethod
def get_impl_cls() -> Type["FlashInferImpl"]:
return FlashInferImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashInferMetadata
@staticmethod
def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
return FlashInferMetadataBuilder
@staticmethod
def get_state_cls() -> Type["FlashInferState"]:
return FlashInferState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
cache_layout = FlashInferState.get_kv_cache_layout()
assert (cache_layout in ("NHD", "HND"))
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3,
2, 4)
return stride_order
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]
@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
"""
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float
def get_per_layer_parameters(
vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: Dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
impl = layer.impl
assert isinstance(impl, FlashInferImpl)
# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale
per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)
return per_layer_params
def infer_global_hyperparameters(
per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params
class FlashInferState(AttentionState):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
self._workspace_buffer = None
self._decode_wrapper = None
self._prefill_wrapper = None
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = self.runner.vllm_config
self._kv_cache_layout = None
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.zeros(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.runner.device)
return self._workspace_buffer
@staticmethod
def get_kv_cache_layout():
from vllm.v1.attention.backends.utils import _KV_CACHE_LAYOUT_OVERRIDE
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
logger.info_once("Using KV cache layout %s",
_KV_CACHE_LAYOUT_OVERRIDE)
return _KV_CACHE_LAYOUT_OVERRIDE
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
if cache_layout is None:
logger.info_once("Using default KV cache layout NHD")
return "NHD"
logger.info_once("Using KV cache layout %s", cache_layout)
return cache_layout
def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), self.get_kv_cache_layout())
return self._prefill_wrapper
def _get_decode_wrapper(self):
if self._decode_wrapper is None:
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
self.get_kv_cache_layout(),
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_decode_wrapper = None
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._graph_decode_workspace_buffer = self._get_workspace_buffer()
self._graph_indices_buffer = torch.empty(
max_batch_size * self.runner.cache_config.num_gpu_blocks,
dtype=torch.int32,
device=self.runner.device)
self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
dtype=torch.int32,
device=self.runner.device)
self._graph_last_page_len_buffer = torch.empty(
max_batch_size, dtype=torch.int32, device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._graph_decode_workspace_buffer
del self._graph_indices_buffer
del self._graph_indptr_buffer
del self._graph_last_page_len_buffer
del self._graph_decode_wrapper
def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
state = self.__class__(self.runner)
state._workspace_buffer = self._graph_decode_workspace_buffer
state._decode_wrapper = self._graph_decode_wrapper
state._prefill_wrapper = self._get_prefill_wrapper()
return state
def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
assert self._is_graph_capturing
_indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
_last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer,
self.get_kv_cache_layout(),
use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
else:
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
paged_kv_indptr_tensor_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
paged_kv_indices_tensor_host = torch.arange(0,
batch_size,
dtype=torch.int32)
paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
self.runner.block_size,
dtype=torch.int32)
query_start_loc_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
global_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
max_decode_seq_len=0,
seq_lens_tensor=self._graph_seq_lens,
block_tables=self._graph_block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
paged_kv_indices=paged_kv_indices_tensor_host,
paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=self.runner.model_config.get_head_size(),
page_size=self.runner.block_size,
seq_start_loc=None,
query_start_loc=query_start_loc_host,
device=self.runner.device,
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None,
**dataclasses.asdict(global_params),
)
attn_metadata.begin_forward()
return attn_metadata
def get_graph_input_buffers(self,
attn_metadata,
is_encoder_decoder_model: bool = False):
return {
"block_tables": attn_metadata.block_tables,
"seq_lens_tensor": attn_metadata.seq_lens_tensor,
"slot_mapping": attn_metadata.slot_mapping,
}
def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False):
# FlashInfer-specific logic: copy additional tensors
num_total_blocks = attn_metadata.decode_metadata.seq_lens_tensor.shape[
0]
input_buffers["seq_lens_tensor"][:num_total_blocks].copy_(
attn_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"][:num_total_blocks].copy_(
attn_metadata.block_tables, non_blocking=True)
def begin_forward(self, model_input):
assert not self._is_graph_capturing
state = self
use_cuda_graph = model_input.attn_metadata.use_cuda_graph
is_decode = model_input.attn_metadata.num_prefills == 0
# In case of multistep chunked-prefill, there might be prefill requests
# scheduled while CUDA graph mode is enabled. We don't run graph in that
# case.
if use_cuda_graph and is_decode:
if model_input.inputs_embeds is None:
batch_size = model_input.input_tokens.shape[0]
state = (
self.runner.graph_runners[model_input.virtual_engine][(
batch_size, False)].attn_state)
else:
batch_size = model_input.inputs_embeds.shape[0]
state = (
self.runner.graph_runners[model_input.virtual_engine][(
batch_size, True)].attn_state)
model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
)
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
model_input.attn_metadata.begin_forward()
@dataclass
class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
max_decode_seq_len: int
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculavie decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = 1
use_cuda_graph: bool = True
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# Metadata for the prefill stage
seq_start_loc: Optional[torch.Tensor] = None
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None
# used for GPU operations
seq_lens_tensor: Optional[torch.Tensor] = None
block_table_bound: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None
# The number of query/output heads
num_qo_heads: Optional[int] = None
# The number of key/value heads
num_kv_heads: Optional[int] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
# Block size of vllm
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
# The data type of the query
q_data_type: torch.dtype = None
# FlashInfer 0.2 encourages passing host tensors
device: torch.device = torch.device("cpu")
is_profile_run: bool = False
# The FlashInfer backend currently supports only models in which all layers
# share the same following hyperparameters:
# The left (inclusive) window size for the attention window, when
# set to `-1`, the window size will be set to the full length of
# the sequence. Defaults to `-1`.
window_left: int = -1
# The attention logits soft capping value (used in Gemini, Grok and
# Gemma-2, etc.), if not provided, will be set to `0`. If greater
# than 0, the logits will be capped according to formula:
# $$\texttt{logits\_soft\_cap} \times
# \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$,
# where $x$ is the input logits.
logits_soft_cap: Optional[float] = None
# The scale used in softmax, if not provided, will be set to
# `1.0 / sqrt(head_dim)`.
sm_scale: Optional[float] = None
def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f" received {self.head_dim}.")
def begin_forward(self):
if self.num_prefill_tokens > 0:
if self.paged_kv_indices is None:
return
assert self.prefill_wrapper is not None
assert self.query_start_loc is not None
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
assert self.block_table_bound is not None
assert self.seq_lens_tensor is not None
self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0
# We will use flash attention for profiling to
# determine the number of blocks. Therefore,
# we don't need to prepare the input for flashinfer for profile run.
if not self.is_profile_run:
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.plan(
self.query_start_loc,
self.paged_kv_indptr[:self.num_prefills + 1],
self.paged_kv_indices,
self.paged_kv_last_page_len[:self.num_prefills],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.data_type)
if self.num_decode_tokens > 0:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
# handle model warmup path
if self.block_table_bound is not None:
self.block_table_bound = self.block_table_bound.to(self.device)
if self.seq_lens_tensor is not None:
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.plan(
self.paged_kv_indptr[self.num_prefills:],
self.paged_kv_indices,
self.paged_kv_last_page_len[self.num_prefills:],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
# kv-cache data type.
kv_data_type=self.data_type,
# query data type.
q_data_type=self.q_data_type)
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the prefill/decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('prefill_wrapper')
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
@property
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
if self.num_prefills == 0:
return None
return self
@property
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
if self.num_decode_tokens == 0:
return None
return self
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = self.runner.vllm_config
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
self.paged_kv_indices: List[int] = []
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
self.paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = []
self.total_blocks = 0
self.is_profile_run: bool = False
if self.global_hyperparameters is None:
# Infer global hyperparameters, since currently we only support
# models in which all layers share the same values for the
# following hyperparameters:
# - `window_left`
# - `logits_soft_cap`
# - `sm_scale`
inferred_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
self.global_hyperparameters = inferred_params
self.window_left = inferred_params.window_left
self.logits_soft_cap = inferred_params.logits_soft_cap
self.sm_scale = inferred_params.sm_scale
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if inter_data.prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)
is_profile_run = is_block_tables_empty(block_tables)
# Compute slot mapping.
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
# create dummy inputs.
if is_profile_run:
self.is_profile_run = is_profile_run
return
block_table = block_tables[seq_id]
self._update_paged_kv_tensors(block_table, seq_len)
def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
self.total_blocks += len(block_table)
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
self.paged_kv_last_page_len.append(last_page_len)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
decode_query_len = max(query_lens[self.num_prefills:], default=1)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)
last_paged_kv_indptr = self.paged_kv_indptr[-1]
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
cuda_graph_pad_size)
self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert device is not None
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
device="cpu",
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device="cpu",
dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
block_table_bound_tensor = None
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
else:
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
return FlashInferMetadata(
decode_query_len=decode_query_len,
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
block_table_bound=block_table_bound_tensor,
seq_lens_tensor=seq_lens_tensor,
num_qo_heads=self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config),
num_kv_heads=self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config),
head_dim=self.runner.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=seq_start_loc,
query_start_loc=query_start_loc,
device=device,
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
)
class FlashInferImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0 "
"FLASHINFER backend.")
if use_irope:
logger.warning_once(
"Using irope in FlashInfer is not supported yet, it will fall"
" back to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
# TODO: directly write to output tensor
num_heads: int = self.num_heads
head_size: int = self.head_size
num_kv_heads: int = self.num_kv_heads
kv_cache_dtype: str = self.kv_cache_dtype
softmax_scale: float = self.scale
window_size = self.sliding_window
alibi_slopes = self.alibi_slopes
logits_soft_cap = self.logits_soft_cap
num_tokens, hidden_size = query.shape
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
if kv_cache.numel() > 0:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
query = query.contiguous(
) # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
# QKV for prefill.
decode_query = query[num_prefill_tokens:]
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
window_left = window_size[0] if window_size is not None else -1
prefill_output: Optional[torch.Tensor] = None
if num_decode_tokens > 0:
decode_output = torch.empty(decode_query.shape,
dtype=decode_query.dtype,
device=decode_query.device)
else:
decode_output = None
stride_order = FlashInferBackend.get_kv_cache_stride_order()
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache.numel() == 0:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
)
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
assert prefill_meta.prefill_wrapper._causal
assert prefill_meta.prefill_wrapper._window_left == window_left
assert prefill_meta.prefill_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale
prefill_output = prefill_meta.prefill_wrapper.run(
query,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
assert decode_meta.decode_wrapper._window_left == window_left
assert decode_meta.decode_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert decode_meta.decode_wrapper._sm_scale == softmax_scale
# TODO: @pavanimajety Remove this once the switch happens
# inside flashinfer.
if not use_trtllm_attention(
num_decode_tokens, attn_metadata.max_decode_seq_len,
kv_cache_dtype, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim):
decode_meta.decode_wrapper.run(
decode_query,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=decode_output,
)
else:
workspace_buffer = (
decode_meta.decode_wrapper._float_workspace_buffer)
assert FlashInferState.get_kv_cache_layout() == "HND"
trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache.permute(*stride_order),
workspace_buffer=workspace_buffer,
block_tables=attn_metadata.block_tables,
seq_lens=decode_meta.seq_lens_tensor,
max_seq_len=attn_metadata.max_decode_seq_len,
bmm1_scale=layer._k_scale_float * softmax_scale,
bmm2_scale=layer._v_scale_float,
out=decode_output,
)
if prefill_output is None and decode_output is not None:
# Decode only batch.
output, num_tokens = decode_output, num_decode_tokens
elif decode_output is None and prefill_output is not None:
# Prefill only batch.
output, num_tokens = prefill_output, num_prefill_tokens
else:
# Chunked prefill batch does not work with speculative decoding in
# FlashInfer backend, so the query length for decode should be 1.
assert prefill_output is not None
assert decode_output is not None
assert decode_meta is not None
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
......@@ -839,8 +839,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
self.context_chunk_workspace_size // num_prefills_with_context
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
# currently the `gather_and_maybe_dequant_cache` kernel cannot
# handle `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk, self.page_size)
assert max_context_chunk > 0
num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
......@@ -1097,6 +1097,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
):
prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None
......@@ -1118,12 +1119,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
for i in range(iters):
toks = prefill_metadata.context_chunk_seq_tot[i]
ops.gather_cache(
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_tables,
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
batch_size=prefill_metadata.num_prefills,
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.context_chunk_starts[i],
)
......@@ -1180,6 +1183,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata
......@@ -1212,7 +1216,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
output = torch.empty_like(suffix_output)
merge_attn_states(
......@@ -1249,12 +1253,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLAImplBase")
......@@ -1302,7 +1307,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if has_prefill:
output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
attn_metadata, layer._k_scale)
if has_decode:
decode_q_nope, decode_q_pe = decode_q.split(
......
......@@ -20,7 +20,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -529,11 +529,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
def fused_output_quant_supported(self, quant_key: QuantKey):
if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype(
) and static and group_shape == GroupShape.PER_TENSOR
return quant_key == kFp8StaticTensorSym
# Only supported in the Triton backend
return False
......@@ -548,6 +546,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -585,17 +584,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
use prefill sequence attributes
Args:
layer: Attention layer instance.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size * num_kv_heads * head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
output: Optional output tensor.
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
......@@ -606,6 +606,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
if output_block_scale is not None:
raise NotImplementedError(
"fused nvfp4 output quantization is not supported"
" for ROCMFlashAttentionImpl")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
......
......@@ -561,7 +561,7 @@ def get_num_prefill_decode_query_kv_tokens(
Raises:
AssertionError: If the number of encoder tokens in `attn_metadata`
is `None` when required for the calculations.
is `None` when required for the calculations.
"""
num_prefill_query_tokens = 0
num_decode_query_tokens = 0
......
......@@ -432,6 +432,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -470,21 +471,22 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
max_encoder_seq_len)
Args:
layer: Attention layer instance.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size * num_kv_heads * head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
output: Optional output tensor.
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")
......@@ -643,7 +645,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
for API spec.
Args:
output: shape = [num_prefill_tokens, num_heads, head_size]
query: shape = [num_prefill_tokens, num_heads, head_size]
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
......
......@@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
......@@ -54,7 +55,7 @@ def check_xformers_availability():
return USE_XFORMERS_OPS
class Attention(nn.Module):
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
......@@ -128,11 +129,17 @@ class Attention(nn.Module):
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
self._o_scale_float: Optional[float] = None
self.use_mla = use_mla
self.num_heads = num_heads
self.head_size = head_size
......@@ -183,8 +190,7 @@ class Attention(nn.Module):
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.is_cuda_alike(
) and not current_platform.is_cpu()
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
......@@ -291,6 +297,7 @@ class Attention(nn.Module):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
......@@ -488,6 +495,7 @@ def unified_attention_with_output(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
......@@ -503,7 +511,8 @@ def unified_attention_with_output(
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale)
output_scale=output_scale,
output_block_scale=output_block_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
......@@ -515,6 +524,7 @@ def unified_attention_with_output_fake(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
return
......@@ -522,7 +532,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["output"],
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
......@@ -6,12 +6,13 @@ from typing import List, Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_metadata_builder)
subclass_attention_backend)
from ..layer import Attention
......@@ -24,21 +25,23 @@ def create_chunked_local_attention_backend(
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
def build_preprocess_fn(cm: CommonAttentionMetadata):
return make_local_attention_virtual_batches(attention_chunk_size, cm,
block_size)
underlying_builder = underlying_attn_backend.get_builder_cls()
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
common_attn_metadata = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size)
return super().build(common_prefix_len, common_attn_metadata,
fast_build)
# Dynamically create a new attention backend that wraps the
# underlying attention backend but applies
# `make_local_attention_virtual_batches` before calling `build(...)`
builder_cls = subclass_attention_metadata_builder(
name_prefix=prefix,
builder_cls=underlying_attn_backend.get_builder_cls(),
build_preprocess_fn=build_preprocess_fn)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=builder_cls)
builder_cls=ChunkedLocalAttentionBuilder)
return attn_backend
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
subclass_attention_backend)
@functools.lru_cache
def create_encoder_only_attention_backend(
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
prefix = "EncoderOnlyAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
new_common_attn_metadata = copy(common_attn_metadata)
new_common_attn_metadata.causal = False
return super().build(common_prefix_len, new_common_attn_metadata,
fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=EncoderOnlyAttentionBuilder)
return attn_backend
class EncoderOnlyAttention(Attention):
"""
Encoder attention is a special case that doesn't need a KV Cache.
"""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
cache_config: Optional[CacheConfig] = None,
attn_type: Optional[str] = None,
**kwargs):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_encoder_only_attention_backend(
underlying_attn_backend)
else:
# in v0 encoder only attention is handled inside the backends
attn_backend = None
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_ONLY, \
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
super().__init__(num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_ONLY,
**kwargs)
......@@ -67,6 +67,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
......@@ -81,6 +83,8 @@ def flash_mla_with_kvcache(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q.
descale_k: (batch_size), torch.float32. Descaling factors for K.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
......@@ -98,6 +102,8 @@ def flash_mla_with_kvcache(
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
return out, softmax_lse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment