Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
import argparse
import json import json
import os import os
import sys import sys
...@@ -5,6 +6,7 @@ import sys ...@@ -5,6 +6,7 @@ import sys
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import triton import triton
from tqdm import tqdm
from vllm.model_executor.layers.fused_moe import (fused_moe, from vllm.model_executor.layers.fused_moe import (fused_moe,
get_config_file_name) get_config_file_name)
...@@ -12,16 +14,16 @@ from vllm.model_executor.layers.fused_moe import (fused_moe, ...@@ -12,16 +14,16 @@ from vllm.model_executor.layers.fused_moe import (fused_moe,
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def main(): def main(dtype: str):
method = fused_moe method = fused_moe
for bs in [ for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096 2048, 3072, 4096
]: ]:
run_grid(bs, method=method) run_grid(bs, method=method, dtype=dtype)
def run_grid(bs, method): def run_grid(bs, method, dtype: str):
d_model = 4096 d_model = 4096
num_total_experts = 8 num_total_experts = 8
top_k = 2 top_k = 2
...@@ -34,39 +36,29 @@ def run_grid(bs, method): ...@@ -34,39 +36,29 @@ def run_grid(bs, method):
num_trials = 1 num_trials = 1
configs = [] configs = []
if bs <= 16:
BLOCK_SIZES_M = [16]
elif bs <= 32:
BLOCK_SIZES_M = [16, 32]
elif bs <= 64:
BLOCK_SIZES_M = [16, 32, 64]
elif bs <= 128:
BLOCK_SIZES_M = [16, 32, 64, 128]
else:
BLOCK_SIZES_M = [16, 32, 64, 128, 256]
for block_size_n in [32, 64, 128, 256]: for block_size_n in [32, 64, 128, 256]:
for block_size_m in BLOCK_SIZES_M: for block_size_m in [16, 32, 64, 128, 256]:
for block_size_k in [64, 128, 256]: for block_size_k in [64, 128, 256]:
for group_size_m in [1, 16, 32, 64]: for group_size_m in [1, 16, 32, 64]:
for num_warps in [4, 8]: for num_warps in [4, 8]:
for num_stages in [2, 3, 4, 5]:
configs.append({ configs.append({
"BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k, "BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m, "GROUP_SIZE_M": group_size_m,
"num_warps": num_warps, "num_warps": num_warps,
"num_stages": 4, "num_stages": num_stages,
}) })
best_config = None best_config = None
best_time_us = 1e20 best_time_us = 1e20
for config in configs:
print(f'{tp_size=} {bs=}') print(f'{tp_size=} {bs=}')
print(f'{config}')
for config in tqdm(configs):
# warmup # warmup
print('warming up')
try: try:
for _ in range(num_warmup_trials): for _ in range(num_warmup_trials):
run_timing( run_timing(
...@@ -79,12 +71,12 @@ def run_grid(bs, method): ...@@ -79,12 +71,12 @@ def run_grid(bs, method):
model_intermediate_size=model_intermediate_size, model_intermediate_size=model_intermediate_size,
method=method, method=method,
config=config, config=config,
dtype=dtype,
) )
except triton.runtime.autotuner.OutOfResources: except triton.runtime.autotuner.OutOfResources:
continue continue
# trial # trial
print('benchmarking')
for _ in range(num_trials): for _ in range(num_trials):
kernel_dur_ms = run_timing( kernel_dur_ms = run_timing(
num_calls=num_calls, num_calls=num_calls,
...@@ -96,6 +88,7 @@ def run_grid(bs, method): ...@@ -96,6 +88,7 @@ def run_grid(bs, method):
model_intermediate_size=model_intermediate_size, model_intermediate_size=model_intermediate_size,
method=method, method=method,
config=config, config=config,
dtype=dtype,
) )
kernel_dur_us = 1000 * kernel_dur_ms kernel_dur_us = 1000 * kernel_dur_ms
...@@ -105,7 +98,8 @@ def run_grid(bs, method): ...@@ -105,7 +98,8 @@ def run_grid(bs, method):
best_config = config best_config = config
best_time_us = kernel_dur_us best_time_us = kernel_dur_us
print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' tqdm.write(
f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
f'{d_model=} {model_intermediate_size=} {num_layers=}') f'{d_model=} {model_intermediate_size=} {num_layers=}')
...@@ -114,7 +108,8 @@ def run_grid(bs, method): ...@@ -114,7 +108,8 @@ def run_grid(bs, method):
# holds Dict[str, Dict[str, int]] # holds Dict[str, Dict[str, int]]
filename = get_config_file_name(num_total_experts, filename = get_config_file_name(num_total_experts,
model_intermediate_size // tp_size) model_intermediate_size // tp_size,
"float8" if dtype == "float8" else None)
print(f"writing config to file {filename}") print(f"writing config to file {filename}")
existing_content = {} existing_content = {}
if os.path.exists(filename): if os.path.exists(filename):
...@@ -128,27 +123,48 @@ def run_grid(bs, method): ...@@ -128,27 +123,48 @@ def run_grid(bs, method):
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
top_k: int, tp_size: int, model_intermediate_size: int, method, top_k: int, tp_size: int, model_intermediate_size: int, method,
config) -> float: config, dtype: str) -> float:
shard_intermediate_size = model_intermediate_size // tp_size shard_intermediate_size = model_intermediate_size // tp_size
hidden_states = torch.rand( hidden_states = torch.rand(
(bs, d_model), (bs, d_model),
device="cuda:0", device="cuda:0",
dtype=torch.bfloat16, dtype=torch.float16,
) )
ws = torch.rand( w1 = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model), (num_total_experts, 2 * shard_intermediate_size, d_model),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
w2s = torch.rand( w2 = torch.rand(
(num_total_experts, d_model, shard_intermediate_size), (num_total_experts, d_model, shard_intermediate_size),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if dtype == "float8":
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
w2_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
a1_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)
a2_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)
gating_output = F.softmax(torch.rand( gating_output = F.softmax(torch.rand(
(num_calls, bs, num_total_experts), (num_calls, bs, num_total_experts),
device=hidden_states.device, device=hidden_states.device,
...@@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, ...@@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
for i in range(num_calls): for i in range(num_calls):
hidden_states = method( hidden_states = method(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=ws, w1=w1,
w2=w2s, w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
gating_output=gating_output[i], gating_output=gating_output[i],
topk=2, topk=2,
renormalize=True, renormalize=True,
inplace=True, inplace=True,
override_config=config, override_config=config,
use_fp8=dtype == "float8",
) )
end_event.record() end_event.record()
end_event.synchronize() end_event.synchronize()
...@@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, ...@@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(main()) parser = argparse.ArgumentParser(
prog='benchmark_mixtral_moe',
description='Benchmark and tune the fused_moe kernel',
)
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['float8', 'float16'],
help='Data type used for fused_moe kernel computations',
)
args = parser.parse_args()
sys.exit(main(args.dtype))
...@@ -16,7 +16,7 @@ PARTITION_SIZE = 512 ...@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def main( def main(
version: str, version: str,
num_seqs: int, num_seqs: int,
context_len: int, seq_len: int,
num_query_heads: int, num_query_heads: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
...@@ -48,12 +48,12 @@ def main( ...@@ -48,12 +48,12 @@ def main(
dtype=torch.float, dtype=torch.float,
device=device) device=device)
context_lens = [context_len for _ in range(num_seqs)] seq_lens = [seq_len for _ in range(num_seqs)]
max_context_len = max(context_lens) max_seq_len = max(seq_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = [] block_tables = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
...@@ -77,8 +77,7 @@ def main( ...@@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel. # Prepare for the paged attention kernel.
output = torch.empty_like(query) output = torch.empty_like(query)
if version == "v2": if version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) // num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
PARTITION_SIZE)
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size), size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype, dtype=output.dtype,
...@@ -110,9 +109,9 @@ def main( ...@@ -110,9 +109,9 @@ def main(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
...@@ -129,9 +128,9 @@ def main( ...@@ -129,9 +128,9 @@ def main(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, kv_scale,
...@@ -166,7 +165,7 @@ if __name__ == '__main__': ...@@ -166,7 +165,7 @@ if __name__ == '__main__':
choices=["v1", "v2"], choices=["v1", "v2"],
default="v2") default="v2")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096) parser.add_argument("--seq_len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
...@@ -199,7 +198,7 @@ if __name__ == '__main__': ...@@ -199,7 +198,7 @@ if __name__ == '__main__':
main( main(
version=args.version, version=args.version,
num_seqs=args.batch_size, num_seqs=args.batch_size,
context_len=args.context_len, seq_len=args.seq_len,
num_query_heads=args.num_query_heads, num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads, num_kv_heads=args.num_kv_heads,
head_size=args.head_size, head_size=args.head_size,
......
...@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel( ...@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride,
...@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel( ...@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block. // No work to do. Terminate the thread block.
return; return;
} }
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx; const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process. // [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE; const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx; const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
...@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel( ...@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= context_len; const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk; logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_max = mask ? qk_max : fmaxf(qk_max, qk);
...@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel( ...@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
} else { } else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} }
if (block_idx == num_context_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context, // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs. // we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
} }
} }
accs[i] += dot(logits_vec, v_vec); accs[i] += dot(logits_vec, v_vec);
...@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel( ...@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride,
...@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel( ...@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
const float kv_scale) { const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, /* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
} }
...@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel( ...@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int q_stride,
...@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel( ...@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
const float kv_scale) { const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride, kv_scale); q_stride, kv_block_stride, kv_head_stride, kv_scale);
} }
...@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) { const int max_num_partitions) {
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) { if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out. // No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
...@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
num_kv_heads, \ num_kv_heads, \
scale, \ scale, \
block_tables_ptr, \ block_tables_ptr, \
context_lens_ptr, \ seq_lens_ptr, \
max_num_blocks_per_seq, \ max_num_blocks_per_seq, \
alibi_slopes_ptr, \ alibi_slopes_ptr, \
q_stride, \ q_stride, \
...@@ -639,8 +639,8 @@ void paged_attention_v1_launcher( ...@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
int num_kv_heads, int num_kv_heads,
float scale, float scale,
torch::Tensor& block_tables, torch::Tensor& block_tables,
torch::Tensor& context_lens, torch::Tensor& seq_lens,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) { float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -664,11 +664,11 @@ void paged_attention_v1_launcher( ...@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float); int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here! // Keep that in sync with the logic here!
...@@ -715,8 +715,8 @@ void paged_attention_v1_launcher( ...@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
num_kv_heads, \ num_kv_heads, \
scale, \ scale, \
block_tables, \ block_tables, \
context_lens, \ seq_lens, \
max_context_len, \ max_seq_len, \
alibi_slopes, \ alibi_slopes, \
kv_scale); kv_scale);
...@@ -746,9 +746,9 @@ void paged_attention_v1( ...@@ -746,9 +746,9 @@ void paged_attention_v1(
int num_kv_heads, // [num_heads] int num_kv_heads, // [num_heads]
float scale, float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int block_size, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
float kv_scale) { float kv_scale) {
...@@ -790,7 +790,7 @@ void paged_attention_v1( ...@@ -790,7 +790,7 @@ void paged_attention_v1(
num_kv_heads, \ num_kv_heads, \
scale, \ scale, \
block_tables_ptr, \ block_tables_ptr, \
context_lens_ptr, \ seq_lens_ptr, \
max_num_blocks_per_seq, \ max_num_blocks_per_seq, \
alibi_slopes_ptr, \ alibi_slopes_ptr, \
q_stride, \ q_stride, \
...@@ -803,7 +803,7 @@ void paged_attention_v1( ...@@ -803,7 +803,7 @@ void paged_attention_v1(
exp_sums_ptr, \ exp_sums_ptr, \
max_logits_ptr, \ max_logits_ptr, \
tmp_out_ptr, \ tmp_out_ptr, \
context_lens_ptr, \ seq_lens_ptr, \
max_num_partitions); max_num_partitions);
template< template<
...@@ -824,8 +824,8 @@ void paged_attention_v2_launcher( ...@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
int num_kv_heads, int num_kv_heads,
float scale, float scale,
torch::Tensor& block_tables, torch::Tensor& block_tables,
torch::Tensor& context_lens, torch::Tensor& seq_lens,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) { float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -852,10 +852,10 @@ void paged_attention_v2_launcher( ...@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float); int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
...@@ -909,8 +909,8 @@ void paged_attention_v2_launcher( ...@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
num_kv_heads, \ num_kv_heads, \
scale, \ scale, \
block_tables, \ block_tables, \
context_lens, \ seq_lens, \
max_context_len, \ max_seq_len, \
alibi_slopes, \ alibi_slopes, \
kv_scale); kv_scale);
...@@ -943,9 +943,9 @@ void paged_attention_v2( ...@@ -943,9 +943,9 @@ void paged_attention_v2(
int num_kv_heads, // [num_heads] int num_kv_heads, // [num_heads]
float scale, float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int block_size, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
float kv_scale) { float kv_scale) {
......
...@@ -24,6 +24,14 @@ void reshape_and_cache( ...@@ -24,6 +24,14 @@ void reshape_and_cache(
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
const float kv_scale); const float kv_scale);
void reshape_and_cache_flash(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
// Just for unittest // Just for unittest
void convert_fp8( void convert_fp8(
torch::Tensor& src_cache, torch::Tensor& src_cache,
......
...@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel( ...@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
} }
} }
template<typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size]
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride,
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride
+ block_offset * num_heads * head_size
+ head_idx * head_size
+ head_offset;
k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx];
}
}
} // namespace vllm } // namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
...@@ -275,6 +310,51 @@ void reshape_and_cache( ...@@ -275,6 +310,51 @@ void reshape_and_cache(
} }
} }
void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype)
{
// FIXME: only support auto datatype, does not support fp8
if (kv_cache_dtype != "auto") {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = k_cache.size(1);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = k_cache.stride(0);
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
"reshape_and_cache_flash",
[&] {
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
k_cache.data_ptr<scalar_t>(),
v_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(),
block_stride,
key_stride,
value_stride,
num_heads,
head_size,
block_size);
});
}
namespace vllm { namespace vllm {
template<typename Tout, typename Tin> template<typename Tout, typename Tin>
......
...@@ -70,11 +70,11 @@ template <typename T> ...@@ -70,11 +70,11 @@ template <typename T>
FORCE_INLINE std::pair<T, T> FORCE_INLINE std::pair<T, T>
reduceSoftmaxAlibi(T *data, const int size, const int capacity, reduceSoftmaxAlibi(T *data, const int size, const int capacity,
const float alibi_slope, const int start_index, const float alibi_slope, const int start_index,
const int context_len) { const int seq_len) {
data[0] += alibi_slope * (start_index - context_len + 1); data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
data[i] = qk; data[i] = qk;
max = max >= qk ? max : qk; max = max >= qk ? max : qk;
} }
...@@ -225,7 +225,7 @@ struct paged_attention_v1_impl { ...@@ -225,7 +225,7 @@ struct paged_attention_v1_impl {
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs] const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
...@@ -235,32 +235,32 @@ struct paged_attention_v1_impl { ...@@ -235,32 +235,32 @@ struct paged_attention_v1_impl {
static_assert(BLOCK_SIZE == 16); static_assert(BLOCK_SIZE == 16);
int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
const int parallel_work_item_num = omp_get_max_threads(); const int parallel_work_item_num = omp_get_max_threads();
size_t logits_bytes = size_t logits_bytes =
parallel_work_item_num * max_context_len_padded * sizeof(float); parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc( float *logits = (float *)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token. 64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_context_len_padded] // [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1) #pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int context_len = context_lens[seq_idx]; int seq_len = seq_lens[seq_idx];
const int *seq_block_table = const int *seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx; block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = const int last_block_token_num =
context_len - (block_num - 1) * BLOCK_SIZE; seq_len - (block_num - 1) * BLOCK_SIZE;
float *__restrict__ thread_block_logits = float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_context_len_padded; logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
...@@ -278,11 +278,11 @@ struct paged_attention_v1_impl { ...@@ -278,11 +278,11 @@ struct paged_attention_v1_impl {
// Compute softmax // Compute softmax
if (alibi_slopes) { if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, context_len, reduceSoftmaxAlibi(thread_block_logits, seq_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
context_len); seq_len);
} else { } else {
reduceSoftmax(thread_block_logits, context_len, reduceSoftmax(thread_block_logits, seq_len,
block_num * BLOCK_SIZE); block_num * BLOCK_SIZE);
} }
...@@ -340,7 +340,7 @@ struct paged_attention_v1_impl { ...@@ -340,7 +340,7 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \ paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads); num_heads);
...@@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE> ...@@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher( void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, torch::Tensor &block_tables, torch::Tensor &seq_lens,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) { int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher( ...@@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher(
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>(); int *seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
...@@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher( ...@@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher(
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context_lens, max_context_len, alibi_slopes); seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
...@@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, ...@@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor &key_cache, torch::Tensor &value_cache,
int num_kv_heads, float scale, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size, torch::Tensor &seq_lens, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes, const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) { const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
...@@ -448,7 +448,7 @@ struct paged_attention_v2_impl { ...@@ -448,7 +448,7 @@ struct paged_attention_v2_impl {
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs] const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
...@@ -465,22 +465,22 @@ struct paged_attention_v2_impl { ...@@ -465,22 +465,22 @@ struct paged_attention_v2_impl {
for (int partition_idx = 0; partition_idx < max_num_partitions; for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) { ++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE; const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= context_len) if (start_token_idx >= seq_len)
continue; continue;
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1); const bool no_reduce = (partition_num == 1);
const int context_token_num = const int token_num =
(std::min(context_len, start_token_idx + PARTITION_SIZE) - (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx); start_token_idx);
const int block_num = const int block_num =
(context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num = const int last_block_token_num =
context_token_num - (block_num - 1) * BLOCK_SIZE; token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables + const int *seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx + max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE; start_token_idx / BLOCK_SIZE;
...@@ -507,10 +507,10 @@ struct paged_attention_v2_impl { ...@@ -507,10 +507,10 @@ struct paged_attention_v2_impl {
std::pair<float, float> max_and_sum; std::pair<float, float> max_and_sum;
if (alibi_slopes) { if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi( max_and_sum = reduceSoftmaxAlibi(
logits, context_token_num, block_num * BLOCK_SIZE, logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, context_len); alibi_slopes[head_idx], start_token_idx, seq_len);
} else { } else {
max_and_sum = reduceSoftmax(logits, context_token_num, max_and_sum = reduceSoftmax(logits, token_num,
block_num * BLOCK_SIZE); block_num * BLOCK_SIZE);
} }
...@@ -583,9 +583,9 @@ struct paged_attention_v2_impl { ...@@ -583,9 +583,9 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1) #pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1)
continue; continue;
...@@ -612,9 +612,9 @@ struct paged_attention_v2_impl { ...@@ -612,9 +612,9 @@ struct paged_attention_v2_impl {
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1)
continue; continue;
...@@ -649,7 +649,7 @@ struct paged_attention_v2_impl { ...@@ -649,7 +649,7 @@ struct paged_attention_v2_impl {
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \ paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions); max_num_partitions);
...@@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher( ...@@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) { int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher( ...@@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher(
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>(); int *seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
...@@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher( ...@@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher(
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, block_size, \ num_kv_heads, scale, block_tables, seq_lens, block_size, \
max_context_len, alibi_slopes); max_seq_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
...@@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, ...@@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, torch::Tensor &value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables, float scale, torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size, torch::Tensor &seq_lens, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes, const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) { const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
......
...@@ -10,9 +10,9 @@ void paged_attention_v1( ...@@ -10,9 +10,9 @@ void paged_attention_v1(
int num_kv_heads, int num_kv_heads,
float scale, float scale,
torch::Tensor& block_tables, torch::Tensor& block_tables,
torch::Tensor& context_lens, torch::Tensor& seq_lens,
int block_size, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
float kv_scale); float kv_scale);
...@@ -28,9 +28,9 @@ void paged_attention_v2( ...@@ -28,9 +28,9 @@ void paged_attention_v2(
int num_kv_heads, int num_kv_heads,
float scale, float scale,
torch::Tensor& block_tables, torch::Tensor& block_tables,
torch::Tensor& context_lens, torch::Tensor& seq_lens,
int block_size, int block_size,
int max_context_len, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
float kv_scale); float kv_scale);
...@@ -124,6 +124,26 @@ torch::Tensor marlin_gemm( ...@@ -124,6 +124,26 @@ torch::Tensor marlin_gemm(
int64_t size_m, int64_t size_m,
int64_t size_n, int64_t size_n,
int64_t size_k); int64_t size_k);
torch::Tensor gptq_marlin_gemm(
torch::Tensor &a,
torch::Tensor &b_q_weight,
torch::Tensor &b_scales,
torch::Tensor &g_idx,
torch::Tensor &perm,
torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full);
torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight,
torch::Tensor &perm,
int64_t size_k,
int64_t size_n,
int64_t num_bits);
#endif #endif
void squeezellm_gemm( void squeezellm_gemm(
...@@ -146,7 +166,12 @@ void gptq_shuffle( ...@@ -146,7 +166,12 @@ void gptq_shuffle(
torch::Tensor q_perm, torch::Tensor q_perm,
int bit); int bit);
void scaled_fp8_quant( void static_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);
void dynamic_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor& out,
torch::Tensor& input, torch::Tensor& input,
torch::Tensor& scale); torch::Tensor& scale);
......
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)
...@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py // and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig // Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
...@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on // clang-format on
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)
...@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr int tz = 4; constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in < feat_out) { if constexpr (feat_in <= feat_out) {
static_assert(feat_in % vec_size == 0); static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size; constexpr int tx = feat_in / vec_size;
...@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale); int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ #define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \ INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T) INST_BGMV(wide, narrow, in_T, out_T, W_T)
...@@ -10,6 +10,7 @@ TEMPLATE = """ ...@@ -10,6 +10,7 @@ TEMPLATE = """
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501 """.lstrip() # noqa: E501
for input_dtype in DTYPES: for input_dtype in DTYPES:
......
...@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, ...@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _) FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE #undef CASE
#undef CASE_ONESIDE #undef CASE_ONESIDE
default: default:
return false; return false;
} }
return true; return true;
} }
......
...@@ -67,13 +67,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -67,13 +67,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif #endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def( ops.def(
"moe_align_block_size", "moe_align_block_size",
&moe_align_block_size, &moe_align_block_size,
...@@ -93,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -93,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache", "reshape_and_cache",
&reshape_and_cache, &reshape_and_cache,
"Reshape the key and value tensors and cache them"); "Reshape the key and value tensors and cache them");
cache_ops.def(
"reshape_and_cache_flash",
&reshape_and_cache_flash,
"Reshape the key and value tensors and cache them");
cache_ops.def( cache_ops.def(
"convert_fp8", "convert_fp8",
&convert_fp8, &convert_fp8,
......
...@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel( ...@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(
} // namespace vllm } // namespace vllm
void scaled_fp8_quant( void static_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
});
}
void dynamic_scaled_fp8_quant(
torch::Tensor& out, // [..., d] torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d] torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1] torch::Tensor& scale) // [1]
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment