Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
...@@ -16,33 +16,22 @@ Easy, fast, and cheap LLM serving for everyone ...@@ -16,33 +16,22 @@ Easy, fast, and cheap LLM serving for everyone
--- ---
**Ray Summit CPF is Open (June 4th to June 20th)!** **The Fifth vLLM Bay Area Meetup (July 24th 5pm-8pm PT)**
There will be a track for vLLM at the Ray Summit (09/30-10/02, SF) this year! We are excited to announce our fifth vLLM Meetup!
If you have cool projects related to vLLM or LLM inference, we would love to see your proposals. Join us to hear the vLLM's recent updates and the upcoming roadmap.
This will be a great chance for everyone in the community to get together and learn. Additionally, our collaborators from AWS will be presenting their insights and experiences in deploying vLLM.
Please submit your proposal [here](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/eventsite) Register now [here](https://lu.ma/lp0gyjqr) and be part of the event!
**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**
We are thrilled to announce our fourth vLLM Meetup!
The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM.
Please register [here](https://lu.ma/agivllm) and join us!
--- ---
*Latest News* 🔥 *Latest News* 🔥
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) with IBM! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
- [2024/01] Added ROCm 6.0 support to vLLM. - [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) with a16z! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/12] Added ROCm 5.7 support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM. - [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
--- ---
...@@ -58,14 +47,16 @@ vLLM is fast with: ...@@ -58,14 +47,16 @@ vLLM is fast with:
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache - Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
- Optimized CUDA kernels - Optimized CUDA kernels
**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/3924) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)).
vLLM is flexible and easy to use with: vLLM is flexible and easy to use with:
- Seamless integration with popular Hugging Face models - Seamless integration with popular Hugging Face models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor parallelism support for distributed inference - Tensor parallelism and pipeline parallelism support for distributed inference
- Streaming outputs - Streaming outputs
- OpenAI-compatible API server - OpenAI-compatible API server
- Support NVIDIA GPUs and AMD GPUs - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs
- (Experimental) Prefix caching support - (Experimental) Prefix caching support
- (Experimental) Multi-lora support - (Experimental) Multi-lora support
...@@ -109,6 +100,7 @@ vLLM is a community project. Our compute resources for development and testing a ...@@ -109,6 +100,7 @@ vLLM is a community project. Our compute resources for development and testing a
- Databricks - Databricks
- DeepInfra - DeepInfra
- Dropbox - Dropbox
- Google Cloud
- Lambda Lab - Lambda Lab
- NVIDIA - NVIDIA
- Replicate - Replicate
...@@ -118,6 +110,7 @@ vLLM is a community project. Our compute resources for development and testing a ...@@ -118,6 +110,7 @@ vLLM is a community project. Our compute resources for development and testing a
- Trainy - Trainy
- UC Berkeley - UC Berkeley
- UC San Diego - UC San Diego
- ZhenFund
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
......
...@@ -11,7 +11,7 @@ from tqdm import tqdm ...@@ -11,7 +11,7 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptStrictInputs from vllm.inputs import PromptInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -61,7 +61,7 @@ def main(args: argparse.Namespace): ...@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size, size=(args.batch_size,
args.input_len)) args.input_len))
dummy_inputs: List[PromptStrictInputs] = [{ dummy_inputs: List[PromptInputs] = [{
"prompt_token_ids": batch "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()] } for batch in dummy_prompt_token_ids.tolist()]
......
...@@ -20,18 +20,18 @@ DEFAULT_TP_SIZES = [1] ...@@ -20,18 +20,18 @@ DEFAULT_TP_SIZES = [1]
# helpers # helpers
def to_fp8(tensor: torch.tensor) -> torch.tensor: def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp( return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.tensor) -> torch.tensor: def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def make_rand_tensors(dtype: torch.dtype, m: int, n: int, def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.tensor, torch.tensor]: k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5 a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5 b = torch.randn((n, k), device='cuda').t() * 5
...@@ -47,15 +47,15 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int, ...@@ -47,15 +47,15 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
# impl # impl
def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.tensor: out_dtype: torch.dtype) -> torch.Tensor:
return torch.mm(a, b) return torch.mm(a, b)
def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.tensor: out_dtype: torch.dtype) -> torch.Tensor:
return torch._scaled_mm(a, return torch._scaled_mm(a,
b, b,
scale_a=scale_a, scale_a=scale_a,
...@@ -63,9 +63,9 @@ def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, ...@@ -63,9 +63,9 @@ def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
out_dtype=out_dtype) out_dtype=out_dtype)
def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor, def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
scale_a: torch.tensor, scale_b: torch.tensor, scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.tensor: out_dtype: torch.dtype) -> torch.Tensor:
return torch._scaled_mm(a, return torch._scaled_mm(a,
b, b,
scale_a=scale_a, scale_a=scale_a,
...@@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor, ...@@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
use_fast_accum=True) use_fast_accum=True)
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.tensor: out_dtype: torch.dtype) -> torch.Tensor:
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype) return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
# bench # bench
def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.tensor, out_dtype: torch.dtype, label: str, scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
sub_label: str, fn: Callable, description: str) -> TMeasurement: sub_label: str, fn: Callable, description: str) -> TMeasurement:
min_run_time = 1 min_run_time = 1
......
...@@ -100,7 +100,7 @@ def main( ...@@ -100,7 +100,7 @@ def main(
start_time = time.perf_counter() start_time = time.perf_counter()
# Using default kv_scale # Using default kv_scale
kv_scale = 1.0 k_scale = v_scale = 1.0
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
...@@ -117,7 +117,8 @@ def main( ...@@ -117,7 +117,8 @@ def main(
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
elif version == "v2": elif version == "v2":
ops.paged_attention_v2( ops.paged_attention_v2(
...@@ -136,7 +137,8 @@ def main( ...@@ -136,7 +137,8 @@ def main(
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
else: else:
raise ValueError(f"Invalid version: {version}") raise ValueError(f"Invalid version: {version}")
......
...@@ -122,9 +122,9 @@ __device__ void paged_attention_kernel( ...@@ -122,9 +122,9 @@ __device__ void paged_attention_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z; const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y; const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y; const int max_num_partitions = gridDim.y;
...@@ -511,7 +511,7 @@ __device__ void paged_attention_kernel( ...@@ -511,7 +511,7 @@ __device__ void paged_attention_kernel(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
kv_scale); v_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
...@@ -636,17 +636,18 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel( ...@@ -636,17 +636,18 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_head_sliding_step) { const int blocksparse_local_blocks, const int blocksparse_vert_stride,
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens,
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
blocksparse_vert_stride, blocksparse_block_size, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_head_sliding_step); blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
...@@ -675,16 +676,17 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel( ...@@ -675,16 +676,17 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_head_sliding_step) { const int blocksparse_local_blocks, const int blocksparse_vert_stride,
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>( KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale, tp_rank, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
...@@ -808,7 +810,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel( ...@@ -808,7 +810,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
kv_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
...@@ -830,8 +832,8 @@ void paged_attention_v1_launcher( ...@@ -830,8 +832,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -892,7 +894,7 @@ void paged_attention_v1_launcher( ...@@ -892,7 +894,7 @@ void paged_attention_v1_launcher(
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \ IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); blocksparse_block_size, blocksparse_head_sliding_step);
...@@ -942,8 +944,8 @@ void paged_attention_v1( ...@@ -942,8 +944,8 @@ void paged_attention_v1(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
...@@ -960,7 +962,7 @@ void paged_attention_v1( ...@@ -960,7 +962,7 @@ void paged_attention_v1(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \ blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
...@@ -977,8 +979,8 @@ void paged_attention_v2_launcher( ...@@ -977,8 +979,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -1044,8 +1046,9 @@ void paged_attention_v2_launcher( ...@@ -1044,8 +1046,9 @@ void paged_attention_v2_launcher(
IS_BLOCK_SPARSE>( \ IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_block_size, blocksparse_head_sliding_step); blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ switch (is_block_sparse) { \
...@@ -1097,8 +1100,8 @@ void paged_attention_v2( ...@@ -1097,8 +1100,8 @@ void paged_attention_v2(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
......
...@@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches, ...@@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, const double k_scale,
const double kv_scale); const double v_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
......
...@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel( ...@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
// block_size] // block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x, const int head_size, const int block_size, const int x, const float k_scale,
const float kv_scale) { const float v_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) { if (slot_idx < 0) {
...@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( ...@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else { } else {
key_cache[tgt_key_idx] = key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
value_cache[tgt_value_idx] = value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
} }
} }
} }
...@@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel( ...@@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel(
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, kv_scale); num_heads, head_size, block_size, x, k_scale, v_scale);
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
...@@ -258,7 +258,8 @@ void reshape_and_cache( ...@@ -258,7 +258,8 @@ void reshape_and_cache(
torch::Tensor& torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size] value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const double kv_scale) { const std::string& kv_cache_dtype, const double k_scale,
const double v_scale) {
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
...@@ -318,13 +319,13 @@ namespace vllm { ...@@ -318,13 +319,13 @@ namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache, Tout* __restrict__ dst_cache,
const float kv_scale, const float scale,
const int64_t block_stride) { const int64_t block_stride) {
const int64_t block_idx = blockIdx.x; const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i; int64_t idx = block_idx * block_stride + i;
dst_cache[idx] = dst_cache[idx] =
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale); fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
} }
} }
...@@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, ...@@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \ vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride); reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
// Only for testing. // Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double kv_scale, const std::string& kv_cache_dtype) { const double scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device(); torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device(); torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
......
...@@ -423,11 +423,11 @@ void paged_attention_v1( ...@@ -423,11 +423,11 @@ void paged_attention_v1(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1, TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet."); "CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
...@@ -742,11 +742,11 @@ void paged_attention_v2( ...@@ -742,11 +742,11 @@ void paged_attention_v2(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1, TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet."); "CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
......
...@@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches, ...@@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, double kv_scale) { const std::string& kv_cache_dtype, double k_scale,
TORCH_CHECK(kv_scale == 1.0f); double v_scale) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
......
...@@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank," " str kv_cache_dtype, float k_scale, float v_scale,"
" int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
...@@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank," " str kv_cache_dtype, float k_scale, float v_scale,"
" int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
...@@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache, Tensor! value_cache," " Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping," " Tensor slot_mapping,"
" str kv_cache_dtype," " str kv_cache_dtype,"
" float kv_scale) -> ()"); " float k_scale, float v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
} }
......
...@@ -8,8 +8,8 @@ void paged_attention_v1( ...@@ -8,8 +8,8 @@ void paged_attention_v1(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -19,8 +19,8 @@ void paged_attention_v2( ...@@ -19,8 +19,8 @@ void paged_attention_v2(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -59,6 +59,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); ...@@ -59,6 +59,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_quick(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input);
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& codebooks,
...@@ -91,15 +96,19 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -91,15 +96,19 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int64_t size_k); int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx, torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& perm, torch::Tensor& workspace, torch::Tensor& g_idx, torch::Tensor& perm,
int64_t num_bits, int64_t size_m, int64_t size_n, torch::Tensor& workspace, int64_t num_bits,
int64_t size_k, bool is_k_full); int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits); int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace, torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n, int64_t num_bits, int64_t size_m, int64_t size_n,
...@@ -132,12 +141,16 @@ void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); ...@@ -132,12 +141,16 @@ void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col); void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale); // torch::Tensor const& scale);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale); // torch::Tensor& scale);
// void dynamic_per_token_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// c10::optional<torch::Tensor> const& scale_ub);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
......
/*
* The goal of this GPU kernel is to advance input tensors on the GPU directly
* PR: https://github.com/vllm-project/vllm/pull/6338
* Current restrictions:
* 1. Specialized for DraftModelRunner
* 2. Supports flash_attn only
*/
#include "advance_step.cuh"
namespace prepare_inputs {
//
template <int const num_threads>
__global__ void advance_step_kernel(int num_seqs, int num_queries,
int block_size, long* input_tokens_ptr,
long const* sampled_token_ids_ptr,
long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr,
int const* block_tables_ptr,
int64_t const block_tables_stride) {
int num_query_blocks = div_ceil(num_queries, num_threads);
if (blockIdx.x >= num_query_blocks) {
return;
}
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
if (cur_query_id >= num_queries) {
return;
}
// Update input_tokens
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
int seq_len = seq_lens_ptr[cur_query_id];
int next_seq_len = seq_len + 1;
int next_input_pos = next_seq_len - 1;
// Update seq_lens
seq_lens_ptr[cur_query_id] = next_seq_len;
// Update input_positions
input_positions_ptr[cur_query_id] = next_input_pos;
int const* seq_block_tables_ptr =
block_tables_ptr + block_tables_stride * cur_query_id;
int block_index = next_input_pos / block_size;
int block_offset = next_input_pos % block_size;
int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
// Update slot_mapping
slot_mapping_ptr[cur_query_id] = slot_num;
}
inline void verify_tensor(std::string const& name, torch::Tensor& t,
int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) {
bool size_0_cond = true;
if (size_0 != -1) {
size_0_cond = t.size(0) == size_0;
}
bool size_1_cond = true;
if (size_1 != -1) {
size_1_cond = t.size(1) == size_1;
}
bool is_contiguous = t.is_contiguous();
bool same_type = t.dtype() == type;
bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
if (!pass) {
TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
" is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
" is not as expected: shape = [", size_0, ", ", size_1,
"], type = ", type);
}
}
void advance_step(int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
torch::Tensor& input_positions, // type: long
torch::Tensor& seq_lens, // type: int
torch::Tensor& slot_mapping, // type: long
torch::Tensor& block_tables) { // type: int
if (logging) {
printf("advance_step:\n");
printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size);
}
// Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
at::kLong);
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
int dev = sampled_token_ids.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
reinterpret_cast<long*>(input_positions.data_ptr()),
reinterpret_cast<int*>(seq_lens.data_ptr()),
reinterpret_cast<long*>(slot_mapping.data_ptr()),
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0));
}
} // namespace prepare_inputs
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
sampled_token_ids, input_positions, seq_lens,
slot_mapping, block_tables);
}
\ No newline at end of file
#pragma once
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
namespace prepare_inputs {
static constexpr int max_threads = 256;
static constexpr bool logging = false;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
} // namespace prepare_inputs
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "../../reduction_utils.cuh"
namespace vllm { namespace vllm {
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
...@@ -21,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { ...@@ -21,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max() #define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
template <typename scalar_t> template <bool is_scale_inverted>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion( __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float inverted_scale) { float const val, float const scale) {
float x = static_cast<float>(val) * inverted_scale; float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r); return static_cast<c10::Float8_e4m3fn>(r);
} }
...@@ -87,6 +95,70 @@ typedef struct __align__(4) { ...@@ -87,6 +95,70 @@ typedef struct __align__(4) {
} }
float8x4_t; float8x4_t;
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
int64_t const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;
#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}
return absmax_val;
}
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
scalar_t const* __restrict__ input,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
int64_t const num_vec_elems = num_elems >> 2;
#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}
template <typename scalar_t> template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input, const scalar_t* __restrict__ input,
...@@ -97,38 +169,68 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, ...@@ -97,38 +169,68 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
// Invert the scale so that we can use multiplications to avoid expensive // Invert the scale so that we can use multiplications to avoid expensive
// division. // division.
const float inverted_scale = 1.0f / (*scale); const float inverted_scale = 1.0f / (*scale);
scaled_fp8_conversion_vec<scalar_t, true>(
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}
// Vectorized input/output to better utilize memory bandwidth. template <typename scalar_t>
const vec4_t<scalar_t>* vectorized_in = __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
reinterpret_cast<const vec4_t<scalar_t>*>(input); c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out); scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
int num_vec_elems = num_elems >> 2; int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
#pragma unroll 4 scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) { c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;
out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale); // For vectorization, token_input and token_output pointers need to be
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale); // aligned at 8-byte and 4-byte addresses respectively.
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale); bool const can_vectorize = hidden_size % 4 == 0;
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
vectorized_out[i] = out_vec; float absmax_val = 0.0f;
if (can_vectorize) {
absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(token_input[i]);
absmax_val = max(absmax_val, fabs(x));
}
} }
// Handle the remaining elements if num_elems is not divisible by 4 float const block_absmax_val_maybe = blockReduceMax(absmax_val);
for (int i = num_vec_elems * 4 + tid; i < num_elems; __shared__ float token_scale;
i += blockDim.x * gridDim.x) { if (tid == 0) {
out[i] = scaled_fp8_conversion(input[i], inverted_scale); if (scale_ub) {
token_scale = min(block_absmax_val_maybe, *scale_ub);
} else {
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
scale[token_idx] = token_scale;
}
__syncthreads();
// Note that we don't use inverted scales so we can match FBGemm impl.
if (can_vectorize) {
scaled_fp8_conversion_vec<scalar_t, false>(
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion<false>(
static_cast<float>(token_input[i]), token_scale);
}
} }
} }
} // namespace vllm } // namespace vllm
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d] torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1] torch::Tensor const& scale) // [1]
{ {
int64_t num_tokens = input.numel() / input.size(-1); int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel(); int64_t num_elems = input.numel();
...@@ -144,9 +246,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] ...@@ -144,9 +246,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
}); });
} }
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d] torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1] torch::Tensor& scale) // [1]
{ {
int64_t num_tokens = input.numel() / input.size(-1); int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel(); int64_t num_elems = input.numel();
...@@ -163,3 +265,28 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] ...@@ -163,3 +265,28 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
scale.data_ptr<float>(), num_elems); scale.data_ptr<float>(), num_elems);
}); });
} }
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size);
});
}
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
* Adapted from https://github.com/IST-DASLab/marlin * Adapted from https://github.com/IST-DASLab/marlin
*/ */
#include "../gptq_marlin/gptq_marlin.cuh" #include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/gptq_marlin_dtypes.cuh" #include "../gptq_marlin/marlin_dtypes.cuh"
using namespace gptq_marlin; using namespace marlin;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \ static_assert(std::is_same<scalar_t, half>::value || \
...@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = ", size_k); ", size_k = ", size_k);
// Verify B // Verify B
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", gptq_marlin::tile_size); " is not divisible by tile_size = ", marlin::tile_size);
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = ", b_q_weight.size(1), "b_q_weight.size(1) = ", b_q_weight.size(1),
" is not divisible by tile_size = ", gptq_marlin::tile_size); " is not divisible by tile_size = ", marlin::tile_size);
int actual_size_n = int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n); ", actual_size_n = ", actual_size_n);
...@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
num_groups = b_scales.size(0); num_groups = b_scales.size(0);
// Verify workspace size // Verify workspace size
TORCH_CHECK( TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, ", is not divisible by min_thread_n = ", marlin::min_thread_n);
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
int min_workspace_size =
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size, TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(), "workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size); " is below min_workspace_size = ", min_workspace_size);
...@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
b_scales.data_ptr<at::Half>(), size_m, size_n, size_k, b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, num_groups, group_size, dev, workspace.data_ptr(), num_bits, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
gptq_marlin::max_par); marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) { } else if (a.scalar_type() == at::ScalarType::BFloat16) {
fp8_marlin::marlin_mm_f16i4<nv_bfloat16>( fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(), a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m, c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
gptq_marlin::max_par); marlin::max_par);
} else { } else {
TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
} }
......
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
namespace marlin {
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe;
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(
&(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
first_n_packed + (n_id * 4)])));
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor;
constexpr int sh_stride = tile_n_ints;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh + stage_size * pipe;
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
// Undo interleaving
int cur_n_pos_unpacked;
if constexpr (num_bits == 4) {
constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
} else {
constexpr int undo_pack[4] = {0, 2, 1, 3};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
}
uint32_t vals[8];
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
} // namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size);
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK(b_q_weight.size(0) == size_k,
"b_q_weight.size(0) = ", b_q_weight.size(0),
" is not size_k = ", size_k);
TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
", size_n = ", size_n, ", pack_factor = ", pack_factor);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out = torch::empty(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
// Get ptrs
uint32_t const* b_q_weight_ptr =
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (false) {
}
CALL_IF(4)
CALL_IF(8)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
}
return out;
}
#endif
#include "gptq_marlin.cuh" #include "marlin.cuh"
namespace gptq_marlin {
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel( __global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {} int size_k, int size_n) {}
} // namespace gptq_marlin } // namespace marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
...@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
#else #else
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel( __global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
...@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel( ...@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
} }
} }
} // namespace gptq_marlin } // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \ #define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
NUM_BITS, HAS_PERM>, \ HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \ HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits) { int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); " is not divisible by tile_n_size = ", marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8, TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits); "num_bits must be 4 or 8. Got = ", num_bits);
...@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
auto options = torch::TensorOptions() auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype()) .dtype(b_q_weight.dtype())
.device(b_q_weight.device()); .device(b_q_weight.device());
torch::Tensor out = torch::Tensor out = torch::empty(
torch::empty({size_k / gptq_marlin::tile_size, {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
size_n * gptq_marlin::tile_size / pack_factor}, options);
options);
// Detect if there is act_order // Detect if there is act_order
bool has_perm = perm.size(0) != 0; bool has_perm = perm.size(0) != 0;
......
...@@ -9,7 +9,9 @@ ...@@ -9,7 +9,9 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream> #include <iostream>
namespace gptq_marlin { namespace marlin {
// Marlin params
// 8 warps are a good choice since every SM has 4 schedulers and having more // 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time, // than 1 warp per schedule allows some more latency hiding. At the same time,
...@@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64; ...@@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16; static constexpr int tile_size = 16;
static constexpr int max_par = 16; static constexpr int max_par = 16;
// Repack params
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
// Helpers
template <typename T, int n> template <typename T, int n>
struct Vec { struct Vec {
T elems[n]; T elems[n];
...@@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() { ...@@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {
#endif #endif
} // namespace gptq_marlin } // namespace marlin
#ifndef _data_types_cuh #ifndef _data_types_cuh
#define _data_types_cuh #define _data_types_cuh
#include "gptq_marlin.cuh" #include "marlin.cuh"
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
namespace gptq_marlin { namespace marlin {
template <typename scalar_t> template <typename scalar_t>
class ScalarType {}; class ScalarType {};
...@@ -23,6 +23,7 @@ class ScalarType<half> { ...@@ -23,6 +23,7 @@ class ScalarType<half> {
using FragB = Vec<half2, 2>; using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; using FragS = Vec<half2, 1>;
using FragZP = Vec<half2, 4>;
static __device__ float inline num2float(const half x) { static __device__ float inline num2float(const half x) {
return __half2float(x); return __half2float(x);
...@@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> { ...@@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
using FragB = Vec<nv_bfloat162, 2>; using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>; using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>; using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) { static __device__ float inline num2float(const nv_bfloat16 x) {
...@@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> { ...@@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
#endif #endif
}; };
} // namespace gptq_marlin } // namespace marlin
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment