Commit 9e053941 authored by zhuwenwen's avatar zhuwenwen
Browse files

skip fp8 kernel and _rocm_C extension

parent f850f22a
......@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
# "csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
# "csrc/quantization/fp8/common.cu"
# "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
......@@ -613,6 +613,7 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP")
#
# _rocm_C extension
......@@ -631,6 +632,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3
WITH_SOABI)
endif()
]]
# For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA")
......
......@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8"
#"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")
......
......@@ -17,43 +17,43 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#endif
#ifndef USE_ROCM
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#else
#define WARP_SIZE warpSize
#endif
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm {
namespace vllm {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
......@@ -72,22 +72,22 @@ inline __device__ float block_sum(float* red_smem, float sum) {
}
// Parallel reduction inside the warp.
#pragma unroll
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0);
}
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
......@@ -178,7 +178,7 @@ __device__ void paged_attention_kernel(
// q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
......@@ -268,7 +268,7 @@ __device__ void paged_attention_kernel(
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
......@@ -310,7 +310,7 @@ __device__ void paged_attention_kernel(
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
......@@ -322,7 +322,7 @@ __device__ void paged_attention_kernel(
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
......@@ -370,7 +370,7 @@ __device__ void paged_attention_kernel(
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
......@@ -401,7 +401,7 @@ __device__ void paged_attention_kernel(
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
......@@ -423,7 +423,7 @@ __device__ void paged_attention_kernel(
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
......@@ -434,10 +434,10 @@ __device__ void paged_attention_kernel(
}
// Perform reduction within each warp.
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
......@@ -450,13 +450,13 @@ __device__ void paged_attention_kernel(
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
......@@ -469,7 +469,7 @@ __device__ void paged_attention_kernel(
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
......@@ -485,7 +485,7 @@ __device__ void paged_attention_kernel(
scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
......@@ -493,13 +493,13 @@ __device__ void paged_attention_kernel(
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel(
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
......@@ -524,14 +524,14 @@ __global__ void paged_attention_v1_kernel(
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
__global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
......@@ -559,12 +559,12 @@ __global__ void paged_attention_v2_kernel(
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
__global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
......@@ -617,7 +617,7 @@ __global__ void paged_attention_v2_reduce_kernel(
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
......@@ -627,7 +627,7 @@ __global__ void paged_attention_v2_reduce_kernel(
__syncthreads();
// Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
......@@ -657,7 +657,7 @@ __global__ void paged_attention_v2_reduce_kernel(
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
......@@ -666,11 +666,11 @@ __global__ void paged_attention_v2_reduce_kernel(
}
from_float(out_ptr[i], acc);
}
}
}
} // namespace vllm
} // namespace vllm
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
......@@ -5,24 +5,26 @@
* Currently, only static fp8 quantization is supported.
*/
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"
#include "type_convert.cuh"
#ifndef USE_ROCM
#include "quantization/fp8/common.cuh"
#endif
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#else
#include <hipcub/hipcub.hpp>
#endif
#endif
namespace vllm {
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename fp8_type>
__global__ void rms_norm_static_fp8_quant_kernel(
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename fp8_type>
__global__ void rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
......@@ -54,15 +56,15 @@ __global__ void rms_norm_static_fp8_quant_kernel(
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
}
}
}
/* Function specialization in the case of FP16/BF16 tensors.
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
......@@ -111,20 +113,20 @@ fused_add_rms_norm_static_fp8_quant_kernel(
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
#pragma unroll
#pragma unroll
for (int i = 0; i < width; ++i) {
out[id * width + i] =
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
}
}
}
}
/* Generic fused_add_rms_norm_kernel
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
......@@ -160,11 +162,11 @@ fused_add_rms_norm_static_fp8_quant_kernel(
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
}
}
}
} // namespace vllm
} // namespace vllm
void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
......@@ -187,9 +189,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
epsilon, num_tokens, hidden_size);
});
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
......@@ -203,7 +205,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
epsilon, num_tokens, hidden_size); \
}); \
});
void fused_add_rms_norm_static_fp8_quant(
void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out, // [..., hidden_size],
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
......@@ -239,4 +241,4 @@ void fused_add_rms_norm_static_fp8_quant(
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
}
\ No newline at end of file
......@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale,
// double epsilon);
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
torch::Tensor& scale, double epsilon);
// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
// torch::Tensor& input,
// torch::Tensor& residual,
// torch::Tensor& weight,
// torch::Tensor& scale, double epsilon);
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input,
......@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale);
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std::optional<torch::Tensor> const& scale_ub);
// void dynamic_per_token_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
......
#pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
......
......@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"
// #include "quantization/fp8/common.cuh"
namespace vllm {
......
......@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
// #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
// #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
// atomicAdd overload for `half` operands.
// Forwards `address` and `val` unchanged to atomicAdd_half; returns nothing.
// Only compiled when the enclosing preprocessor guard holds
// (__CUDA_ARCH__ < 700 or USE_ROCM defined).
__device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val);
}
// __device__ __forceinline__ void atomicAdd(half* address, half val) {
// atomicAdd_half(address, val);
// }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
// atomicAdd overload for `half2` operands.
// Forwards `address` and `val` unchanged to atomicAdd_half2; returns nothing.
// Only compiled when the enclosing preprocessor guard holds
// (__CUDA_ARCH__ < 600 or USE_ROCM defined).
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val);
}
#endif
// #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
// __device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
// atomicAdd_half2(address, val);
// }
// #endif
#endif
#endif
// #endif
// #endif
} // namespace gptq
} // namespace vllm
......
......@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> "
"()");
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&rms_norm_static_fp8_quant);
// ops.def(
// "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
// "Tensor scale, float epsilon) -> "
// "()");
// ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
// &rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()");
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&fused_add_rms_norm_static_fp8_quant);
// ops.def(
// "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! residual, Tensor weight, "
// "Tensor scale, float epsilon) -> ()");
// ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
// &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels
ops.def(
......@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// ops.def(
// "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
// "()");
// ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
// "-> "
// "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);
// ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! scale, Tensor? scale_ub) -> "
// "()");
// ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
// &dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
......
......@@ -643,8 +643,8 @@ ext_modules = []
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
# if _is_hip():
# ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
......
......@@ -98,30 +98,30 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step)
# Thin Python wrapper around the `paged_attention` op registered under
# `torch.ops._rocm_C`.
#
# Every argument is forwarded positionally, in declaration order, to
# `torch.ops._rocm_C.paged_attention`; nothing is validated or transformed
# here. The wrapper returns None.
#
# NOTE(review): this call can only resolve if the `_rocm_C` extension has been
# built and loaded -- confirm against the build configuration before relying
# on it.
def paged_attention_rocm(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale)
# def paged_attention_rocm(
# out: torch.Tensor,
# exp_sum: torch.Tensor,
# max_logits: torch.Tensor,
# tmp_out: torch.Tensor,
# query: torch.Tensor,
# key_cache: torch.Tensor,
# value_cache: torch.Tensor,
# num_kv_heads: int,
# scale: float,
# block_tables: torch.Tensor,
# seq_lens: torch.Tensor,
# block_size: int,
# max_seq_len: int,
# alibi_slopes: Optional[torch.Tensor],
# kv_cache_dtype: str,
# k_scale: torch.Tensor,
# v_scale: torch.Tensor,
# ) -> None:
# torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
# key_cache, value_cache, num_kv_heads,
# scale, block_tables, seq_lens,
# block_size, max_seq_len, alibi_slopes,
# kv_cache_dtype, k_scale, v_scale)
# pos encoding ops
......
......@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len)
# use_custom = _use_rocm_custom_paged_attention(
# decode_query.dtype, head_size, block_size, gqa_ratio,
# decode_meta.max_decode_seq_len)
use_custom = False
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment