"docs/vscode:/vscode.git/clone" did not exist on "8ceffbf3152d3b26d293ba1e157d0c187884572b"
Commit b9e12416 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
import argparse
import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
size_m, size_k, size_n):
label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, b={}, g={}, "
"MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
group_size, size_m, size_k, size_n))
print(f"Testing: {sub_label}")
a = torch.randn(size_m, size_k).to(torch.half).cuda()
b = torch.rand(size_k, size_n).to(torch.half).cuda()
a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
# Marlin quant
(
marlin_w_ref,
marlin_q_w,
marlin_s,
marlin_g_idx,
marlin_sort_indices,
marlin_rand_perm,
) = marlin_quantize(b, num_bits, group_size, act_order)
# Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
# GPTQ quant
(w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx"
# so that group ids are increasing
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
if act_order:
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
# Prepare
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
globals = {
# Gen params
"num_bits": num_bits,
"group_size": group_size,
"size_m": size_m,
"size_n": size_n,
"size_k": size_k,
"a": a,
"a_tmp": a_tmp,
# Marlin params
"marlin_w_ref": marlin_w_ref,
"marlin_q_w": marlin_q_w,
"marlin_s": marlin_s,
"marlin_g_idx": marlin_g_idx,
"marlin_sort_indices": marlin_sort_indices,
"marlin_rand_perm": marlin_rand_perm,
"marlin_workspace": marlin_workspace,
"is_k_full": is_k_full,
# Marlin_24 params
"marlin_24_w_ref": marlin_24_w_ref,
"marlin_24_q_w_comp": marlin_24_q_w_comp,
"marlin_24_meta": marlin_24_meta,
"marlin_24_s": marlin_24_s,
"marlin_24_workspace": marlin_24_workspace,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
}
min_run_time = 1
# Warmup pytorch
for i in range(5):
torch.matmul(a, marlin_w_ref)
results.append(
benchmark.Timer(
stmt="torch.matmul(a, marlin_w_ref)",
globals=globals,
label=label,
sub_label=sub_label,
description="pytorch_gemm",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm",
).blocked_autorange(min_run_time=min_run_time))
if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_24_gemm",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time))
def main(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
results = []
for model in args.models:
for layer in WEIGHT_SHAPES[model]:
size_k = layer[0]
size_n = layer[1]
if len(args.limit_k) > 0 and size_k not in args.limit_k:
continue
if len(args.limit_n) > 0 and size_n not in args.limit_n:
continue
for act_order in ACT_ORDER_OPTS:
if len(args.limit_act_order
) > 0 and act_order not in args.limit_act_order:
continue
for is_k_full in K_FULL_OPTS:
if len(args.limit_k_full
) > 0 and is_k_full not in args.limit_k_full:
continue
for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
if len(args.limit_num_bits
) > 0 and num_bits not in args.limit_num_bits:
continue
for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
if len(
args.limit_group_size
) > 0 and group_size not in args.limit_group_size:
continue
# For act_order, the group_size must be less than
# size_k
if act_order and (group_size == size_k
or group_size == -1):
continue
for size_m in args.batch_sizes:
bench_run(results, model, act_order, is_k_full,
num_bits, group_size, size_m, size_k,
size_n)
compare = benchmark.Compare(results)
compare.print()
# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])
args = parser.parse_args()
main(args)
...@@ -170,7 +170,7 @@ if __name__ == '__main__': ...@@ -170,7 +170,7 @@ if __name__ == '__main__':
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 256], choices=[64, 80, 96, 112, 128, 192, 256],
default=128) default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")
...@@ -183,13 +183,11 @@ if __name__ == '__main__': ...@@ -183,13 +183,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", "--kv-cache-dtype",
type=str, type=str,
choices=["auto", "fp8"], choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto", default="auto",
help= help="Data type for kv cache storage. If 'auto', will use model "
'Data type for kv cache storage. If "auto", will use model data type. ' "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
'FP8_E5M2 (without scaling) is only supported on cuda version greater ' "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
...@@ -93,7 +93,7 @@ if __name__ == '__main__': ...@@ -93,7 +93,7 @@ if __name__ == '__main__':
parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 256], choices=[64, 80, 96, 112, 128, 192, 256],
default=128) default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype", parser.add_argument("--dtype",
......
WEIGHT_SHAPES = {
"ideal": [[4 * 256 * 32, 256 * 32]],
"mistralai/Mistral-7B-v0.1/TP1": [
[4096, 6144],
[4096, 4096],
[4096, 28672],
[14336, 4096],
],
"mistralai/Mistral-7B-v0.1/TP2": [
[4096, 3072],
[2048, 4096],
[4096, 14336],
[7168, 4096],
],
"mistralai/Mistral-7B-v0.1/TP4": [
[4096, 1536],
[1024, 4096],
[4096, 7168],
[3584, 4096],
],
"meta-llama/Llama-2-7b-hf/TP1": [
[4096, 12288],
[4096, 4096],
[4096, 22016],
[11008, 4096],
],
"meta-llama/Llama-2-7b-hf/TP2": [
[4096, 6144],
[2048, 4096],
[4096, 11008],
[5504, 4096],
],
"meta-llama/Llama-2-7b-hf/TP4": [
[4096, 3072],
[1024, 4096],
[4096, 5504],
[2752, 4096],
],
"meta-llama/Llama-2-13b-hf/TP1": [
[5120, 15360],
[5120, 5120],
[5120, 27648],
[13824, 5120],
],
"meta-llama/Llama-2-13b-hf/TP2": [
[5120, 7680],
[2560, 5120],
[5120, 13824],
[6912, 5120],
],
"meta-llama/Llama-2-13b-hf/TP4": [
[5120, 3840],
[1280, 5120],
[5120, 6912],
[3456, 5120],
],
"meta-llama/Llama-2-70b-hf/TP1": [
[8192, 10240],
[8192, 8192],
[8192, 57344],
[28672, 8192],
],
"meta-llama/Llama-2-70b-hf/TP2": [
[8192, 5120],
[4096, 8192],
[8192, 28672],
[14336, 8192],
],
"meta-llama/Llama-2-70b-hf/TP4": [
[8192, 2560],
[2048, 8192],
[8192, 14336],
[7168, 8192],
],
}
...@@ -4,7 +4,7 @@ PORT=8000 ...@@ -4,7 +4,7 @@ PORT=8000
MODEL=$1 MODEL=$1
TOKENS=$2 TOKENS=$2
docker run --gpus all --shm-size 1g -p $PORT:80 \ docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
-v $PWD/data:/data \ -v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:1.4.0 \ ghcr.io/huggingface/text-generation-inference:1.4.0 \
--model-id $MODEL \ --model-id $MODEL \
......
import argparse
import cProfile
import pstats
from vllm import LLM, SamplingParams
# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)
def main(args):
llm = LLM(
model=args.model,
enforce_eager=True,
enable_prefix_caching=True,
tensor_parallel_size=args.tensor_parallel_size,
use_v2_block_manager=args.use_v2_block_manager,
)
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
profiler = cProfile.Profile()
print("------warm up------")
for i in range(3):
output = llm.generate(LONG_PROMPT, sampling_params)
print(output[0].outputs[0].text)
print("------start generating------")
for i in range(3):
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
globals(), locals())
# analyze the runtime of hashing function
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
total_time = 0
total_calls = 0
for func in stats.stats:
if 'hash_of_block' in func[2]:
total_time = stats.stats[func][3]
total_calls = stats.stats[func][0]
percentage = (total_time / stats.total_tt) * 100
print(f"Hashing took {total_time:.2f} seconds,"
f"{percentage:.2f}% of the total runtime.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Benchmark the performance of hashing function in'
'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
args = parser.parse_args()
main(args)
...@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"Failed to determine torch nvcc compiler flags") "Failed to determine torch nvcc compiler flags")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") list(APPEND GPU_FLAGS "-DENABLE_FP8")
endif() endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS list(REMOVE_ITEM GPU_FLAGS
...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
# "-DENABLE_FP8_E4M3" # "-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc" "-fno-gpu-rdc"
......
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
namespace vllm { namespace vllm {
// Activation and gating kernel template. // Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel( __global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
...@@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel( ...@@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel(
} }
} }
template<typename T> template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) { __device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x) // x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x))); return (T)(((float)x) / (1.0f + expf((float)-x)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) { __device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation. // Equivalent to PyTorch GELU with 'none' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float) x; const float f = (float)x;
constexpr float ALPHA = M_SQRT1_2; constexpr float ALPHA = M_SQRT1_2;
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation. // Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float) x; const float f = (float)x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715; constexpr float KAPPA = 0.044715;
float x_cube = f * f * f; float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube); float inner = BETA * (f + KAPPA * x_cube);
return (T) (0.5f * f * (1.0f + ::tanhf(inner))); return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
} }
} // namespace vllm } // namespace vllm
// Launch activation and gating kernel. // Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \ int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "act_and_mul_kernel", [&] { \
"act_and_mul_kernel", \ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
[&] { \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ input.data_ptr<scalar_t>(), d); \
out.data_ptr<scalar_t>(), \ });
input.data_ptr<scalar_t>(), \
d); \ void silu_and_mul(torch::Tensor& out, // [..., d]
}); torch::Tensor& input) // [..., 2 * d]
void silu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
} }
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
} }
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
} }
...@@ -96,11 +90,11 @@ void gelu_tanh_and_mul( ...@@ -96,11 +90,11 @@ void gelu_tanh_and_mul(
namespace vllm { namespace vllm {
// Element-wise activation kernel template. // Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel( __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d] const scalar_t* __restrict__ input, // [..., d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
...@@ -108,54 +102,49 @@ __global__ void activation_kernel( ...@@ -108,54 +102,49 @@ __global__ void activation_kernel(
} }
} }
} // namespace vllm } // namespace vllm
// Launch element-wise activation kernel. // Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \ int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \ int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
input.scalar_type(), \ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
"activation_kernel", \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
[&] { \ input.data_ptr<scalar_t>(), d); \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ });
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
namespace vllm { namespace vllm {
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) { __device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float) (x * x * x); const float x3 = (float)(x * x * x);
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T) 0.5) * x * (((T) 1.0) + t); return ((T)0.5) * x * (((T)1.0) + t);
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) { __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float) x; const float f = (float)x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); const T t =
return ((T) 0.5) * x * (((T) 1.0) + t); (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
} }
} // namespace vllm } // namespace vllm
void gelu_new( void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
} }
void gelu_fast( void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
} }
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -22,31 +23,31 @@ ...@@ -22,31 +23,31 @@
namespace vllm { namespace vllm {
// A vector type to store Q, K, V elements. // A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE> template <typename T, int VEC_SIZE>
struct Vec {}; struct Vec {};
// A vector type to store FP32 accumulators. // A vector type to store FP32 accumulators.
template<typename T> template <typename T>
struct FloatVec {}; struct FloatVec {};
// Template vector operations. // Template vector operations.
template<typename Acc, typename A, typename B> template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b); inline __device__ Acc mul(A a, B b);
template<typename T> template <typename T>
inline __device__ float sum(T v); inline __device__ float sum(T v);
template<typename T> template <typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b)); return sum(mul<T, T, T>(a, b));
} }
template<typename A, typename T> template <typename A, typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b)); return sum(mul<A, T, T>(a, b));
} }
template<typename T> template <typename T>
inline __device__ void zero(T& dst) { inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4; constexpr int WORDS = sizeof(T) / 4;
union { union {
...@@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { ...@@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) {
dst = tmp.raw; dst = tmp.raw;
} }
} // namespace vllm } // namespace vllm
This diff is collapsed.
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -26,7 +27,7 @@ ...@@ -26,7 +27,7 @@
namespace vllm { namespace vllm {
// Q*K^T operation. // Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N> template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type; using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately). // Compute the parallel products for Q*K^T (treat vector lanes separately).
...@@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { ...@@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk; return qk;
} }
template<typename T, int THREAD_GROUP_SIZE> template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot { struct Qk_dot {
template<typename Vec, int N> template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k); return qk_dot_<THREAD_GROUP_SIZE>(q, k);
} }
}; };
} // namespace vllm } // namespace vllm
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -28,8 +30,8 @@ ...@@ -28,8 +30,8 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
typedef __hip_bfloat162 __nv_bfloat162; typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
#include <stdint.h> #include <stdint.h>
...@@ -50,37 +52,37 @@ struct bf16_8_t { ...@@ -50,37 +52,37 @@ struct bf16_8_t {
}; };
// BF16 vector types for Q, K, V. // BF16 vector types for Q, K, V.
template<> template <>
struct Vec<__nv_bfloat16, 1> { struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16; using Type = __nv_bfloat16;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 2> { struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162; using Type = __nv_bfloat162;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 4> { struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t; using Type = bf16_4_t;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 8> { struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t; using Type = bf16_8_t;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<__nv_bfloat16> { struct FloatVec<__nv_bfloat16> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<__nv_bfloat162> { struct FloatVec<__nv_bfloat162> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<bf16_4_t> { struct FloatVec<bf16_4_t> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<bf16_8_t> { struct FloatVec<bf16_8_t> {
using Type = Float8_; using Type = Float8_;
}; };
...@@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { ...@@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
// assert(false); // assert(false);
// #else // #else
#ifndef USE_ROCM #ifndef USE_ROCM
return a + b; return a + b;
#else #else
return __hadd(a, b); return __hadd(a, b);
#endif #endif
// #endif // #endif
} }
...@@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { ...@@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false); // assert(false);
...@@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { ...@@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
// #endif // #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false); // assert(false);
...@@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { ...@@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
// #endif // #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c; bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { ...@@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_4_t c; bf16_4_t c;
...@@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { ...@@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c; bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { ...@@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_8_t c; bf16_8_t c;
...@@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { ...@@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a); float fa = __bfloat162float(a);
float fb = __bfloat162float(b); float fb = __bfloat162float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a); float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b); float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { ...@@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float4_ fc; Float4_ fc;
...@@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { ...@@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { ...@@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float8_ fc; Float8_ fc;
...@@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { ...@@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false); // assert(false);
// #else // #else
...@@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf ...@@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
// #endif // #endif
} }
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false); // assert(false);
// #else // #else
...@@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { ...@@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(__nv_bfloat16 v) { inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v); return __bfloat162float(v);
} }
template<> template <>
inline __device__ float sum(__nv_bfloat162 v) { inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v); float2 vf = bf1622float2(v);
return vf.x + vf.y; return vf.x + vf.y;
} }
template<> template <>
inline __device__ float sum(bf16_4_t v) { inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y); return sum(v.x) + sum(v.y);
} }
template<> template <>
inline __device__ float sum(bf16_8_t v) { inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
} }
...@@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { ...@@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) {
// #endif // #endif
} }
} // namespace vllm } // namespace vllm
\ No newline at end of file
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -30,37 +32,37 @@ ...@@ -30,37 +32,37 @@
namespace vllm { namespace vllm {
// FP16 vector types for Q, K, V. // FP16 vector types for Q, K, V.
template<> template <>
struct Vec<uint16_t, 1> { struct Vec<uint16_t, 1> {
using Type = uint16_t; using Type = uint16_t;
}; };
template<> template <>
struct Vec<uint16_t, 2> { struct Vec<uint16_t, 2> {
using Type = uint32_t; using Type = uint32_t;
}; };
template<> template <>
struct Vec<uint16_t, 4> { struct Vec<uint16_t, 4> {
using Type = uint2; using Type = uint2;
}; };
template<> template <>
struct Vec<uint16_t, 8> { struct Vec<uint16_t, 8> {
using Type = uint4; using Type = uint4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<uint16_t> { struct FloatVec<uint16_t> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<uint32_t> { struct FloatVec<uint32_t> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<uint2> { struct FloatVec<uint2> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<uint4> { struct FloatVec<uint4> {
using Type = Float8_; using Type = Float8_;
}; };
...@@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { ...@@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) {
return b; return b;
#else #else
union { union {
uint32_t u32; uint32_t u32;
uint16_t u16[2]; uint16_t u16[2];
} tmp; } tmp;
tmp.u16[0] = a; tmp.u16[0] = a;
tmp.u16[1] = a; tmp.u16[1] = a;
...@@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { ...@@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
} tmp; } tmp;
#ifndef USE_ROCM #ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else #else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif #endif
#else #else
tmp.u16[0] = float_to_half(f.x); tmp.u16[0] = float_to_half(f.x);
...@@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { ...@@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) { inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c; uint16_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { ...@@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) { inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c; uint32_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { ...@@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) { inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b); return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ uint2 mul(uint2 a, uint2 b) { inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c; uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
...@@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { ...@@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) { inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint2 c; uint2 c;
...@@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { ...@@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint4 a, uint4 b) { inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c; uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
...@@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { ...@@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) { inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint4 c; uint4 c;
...@@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { ...@@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(uint16_t a, uint16_t b) { inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a); float fa = half_to_float(a);
float fb = half_to_float(b); float fb = half_to_float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) { inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a); float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b); float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) { inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b); return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) { inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
...@@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { ...@@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) { inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float4_ fc; Float4_ fc;
...@@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { ...@@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) { inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
...@@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { ...@@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float8_ fc; Float8_ fc;
...@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { ...@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d; uint32_t d;
#ifndef USE_ROCM #ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else #else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
#endif #endif
return d; return d;
} }
...@@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { ...@@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(uint16_t v) { inline __device__ float sum(uint16_t v) {
return half_to_float(v); return half_to_float(v);
} }
template<> template <>
inline __device__ float sum(uint32_t v) { inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v); float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y; return tmp.x + tmp.y;
} }
template<> template <>
inline __device__ float sum(uint2 v) { inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
return sum(c); return sum(c);
} }
template<> template <>
inline __device__ float sum(uint4 v) { inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
c = add(c, v.z); c = add(c, v.z);
...@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { ...@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
} }
// From float16 to float32. // From float16 to float32.
inline __device__ float to_float(uint16_t u) { inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) { inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) { inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp; Float4_ tmp;
...@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { ...@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
} }
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
dst = uint16_t(0);
}
} // namespace vllm } // namespace vllm
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -38,37 +40,35 @@ struct Float8_ { ...@@ -38,37 +40,35 @@ struct Float8_ {
}; };
// FP32 vector types for Q, K, V. // FP32 vector types for Q, K, V.
template<> template <>
struct Vec<float, 1> { struct Vec<float, 1> {
using Type = float; using Type = float;
}; };
template<> template <>
struct Vec<float, 2> { struct Vec<float, 2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct Vec<float, 4> { struct Vec<float, 4> {
using Type = float4; using Type = float4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<float> { struct FloatVec<float> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<float2> { struct FloatVec<float2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<float4> { struct FloatVec<float4> {
using Type = float4; using Type = float4;
}; };
// Vector addition. // Vector addition.
inline __device__ float add(float a, float b) { inline __device__ float add(float a, float b) { return a + b; }
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) { inline __device__ float2 add(float2 a, float2 b) {
float2 c; float2 c;
...@@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { ...@@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ float mul<float, float>(float a, float b) { inline __device__ float mul<float, float>(float a, float b) {
return a * b; return a * b;
} }
template<> template <>
inline __device__ float2 mul(float2 a, float2 b) { inline __device__ float2 mul(float2 a, float2 b) {
float2 c; float2 c;
c.x = a.x * b.x; c.x = a.x * b.x;
...@@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { ...@@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float2 mul(float a, float2 b) { inline __device__ float2 mul(float a, float2 b) {
float2 c; float2 c;
c.x = a * b.x; c.x = a * b.x;
...@@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { ...@@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float4 a, float4 b) { inline __device__ float4 mul(float4 a, float4 b) {
float4 c; float4 c;
c.x = a.x * b.x; c.x = a.x * b.x;
...@@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { ...@@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float a, float4 b) { inline __device__ float4 mul(float a, float4 b) {
float4 c; float4 c;
c.x = a * b.x; c.x = a * b.x;
...@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { ...@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { inline __device__ float fma(float a, float b, float c) { return a * b + c; }
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) { inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d; float2 d;
...@@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { ...@@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(float v) { inline __device__ float sum(float v) {
return v; return v;
} }
template<> template <>
inline __device__ float sum(float2 v) { inline __device__ float sum(float2 v) {
return v.x + v.y; return v.x + v.y;
} }
template<> template <>
inline __device__ float sum(float4 v) { inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w; return v.x + v.y + v.z + v.w;
} }
template<> template <>
inline __device__ float sum(Float4_ v) { inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y; return v.x.x + v.x.y + v.y.x + v.y.y;
} }
template<> template <>
inline __device__ float sum(Float8_ v) { inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
} }
// Vector dot product. // Vector dot product.
inline __device__ float dot(float a, float b) { inline __device__ float dot(float a, float b) { return a * b; }
return a * b;
}
inline __device__ float dot(float2 a, float2 b) { inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b); float2 c = mul<float2, float2, float2>(a, b);
...@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { ...@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
} }
// From float to float. // From float to float.
inline __device__ void from_float(float& dst, float src) { inline __device__ void from_float(float& dst, float src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) { inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) { inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
dst = src;
}
// From float to float. // From float to float.
inline __device__ float to_float(float u) { inline __device__ float to_float(float u) { return u; }
return u;
}
inline __device__ float2 to_float(float2 u) { inline __device__ float2 to_float(float2 u) { return u; }
return u;
}
inline __device__ float4 to_float(float4 u) { inline __device__ float4 to_float(float4 u) { return u; }
return u;
}
inline __device__ Float4_ to_float(Float4_ u) { inline __device__ Float4_ to_float(Float4_ u) { return u; }
return u;
}
inline __device__ Float8_ to_float(Float8_ u) { inline __device__ Float8_ to_float(Float8_ u) { return u; }
return u;
}
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(float& dst) { inline __device__ void zero(float& dst) { dst = 0.f; }
dst = 0.f;
}
} // namespace vllm } // namespace vllm
...@@ -3,33 +3,39 @@ ...@@ -3,33 +3,39 @@
#include "attention_generic.cuh" #include "attention_generic.cuh"
#include <stdint.h> #include <stdint.h>
#ifdef ENABLE_FP8_E5M2 #ifdef ENABLE_FP8
#include <cuda_fp8.h> #ifndef USE_ROCM
#endif #include <cuda_fp8.h>
#endif // USE_ROCM
#endif // ENABLE_FP8
namespace vllm { namespace vllm {
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache
template<> enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};
// fp8 vector types for quantization of kv cache
template <>
struct Vec<uint8_t, 1> { struct Vec<uint8_t, 1> {
using Type = uint8_t; using Type = uint8_t;
}; };
template<> template <>
struct Vec<uint8_t, 2> { struct Vec<uint8_t, 2> {
using Type = uint16_t; using Type = uint16_t;
}; };
template<> template <>
struct Vec<uint8_t, 4> { struct Vec<uint8_t, 4> {
using Type = uint32_t; using Type = uint32_t;
}; };
template<> template <>
struct Vec<uint8_t, 8> { struct Vec<uint8_t, 8> {
using Type = uint2; using Type = uint2;
}; };
#endif // ENABLE_FP8_E5M2
} // namespace vllm } // namespace vllm
...@@ -5,34 +5,24 @@ ...@@ -5,34 +5,24 @@
#include <map> #include <map>
#include <vector> #include <vector>
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping);
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor>& value_caches, const torch::Tensor& block_mapping);
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
void reshape_and_cache( void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& value, torch::Tensor& slot_mapping,
torch::Tensor& key_cache, const std::string& kv_cache_dtype, const float kv_scale);
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const float kv_scale);
void reshape_and_cache_flash( void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache,
torch::Tensor& value, torch::Tensor& value_cache,
torch::Tensor& key_cache, torch::Tensor& slot_mapping,
torch::Tensor& value_cache, const std::string& kv_cache_dtype);
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
// Just for unittest // Just for unittest
void convert_fp8( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& src_cache, const float scale, const std::string& kv_cache_dtype);
torch::Tensor& dst_cache);
This diff is collapsed.
#include "cpu_types.hpp" #include "cpu_types.hpp"
namespace { namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &), template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
bool is_gated> bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
scalar_t *__restrict__ output) { scalar_t* __restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
...@@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, ...@@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
} }
} }
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp()); return x / (ones + (zeros - x).exp());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
...@@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { ...@@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
...@@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { ...@@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er()); return x * w2 * (ones + (x * w1).er());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
...@@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { ...@@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh()); return x * w2 * (ones + inner.tanh());
} }
}; // namespace }; // namespace
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
input.scalar_type(), "silu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
CPU_KERNEL_GUARD_IN(silu_and_mul_impl) activation_kernel<scalar_t, silu_act, true>(
activation_kernel<scalar_t, silu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
} }
void gelu_and_mul(torch::Tensor &out, // [..., d] void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
input.scalar_type(), "gelu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) activation_kernel<scalar_t, gelu_act, true>(
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
} }
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
...@@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] ...@@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
}); });
} }
void gelu_new(torch::Tensor &out, torch::Tensor &input) { void gelu_new(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);
...@@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { ...@@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) {
}); });
} }
void gelu_fast(torch::Tensor &out, torch::Tensor &input) { void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment