"vscode:/vscode.git/clone" did not exist on "43524b61b18e3e9ef1357ba52a0e0bbd31303438"
Commit 51679bbd authored by zhuwenwen's avatar zhuwenwen
Browse files

resolve merge conflicts

parents 4095d0db 1af090b5
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
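// NOTE: each #include/#include/FOR_BGMV_WIDE_NARROW triple above is one
// generated .cu translation unit (see the Python generator script further down
// in this dump); only the dtype combination passed to INST_BGMV_TWOSIDE
// differs between them. The header that follows appears to be bgmv_impl.cuh,
// which every generated file includes.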
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
#include "vec_dtypes.cuh"
namespace cg = cooperative_groups;
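// bgmv_shrink_kernel: one thread block per (output feature j = blockIdx.x,
// batch element = blockIdx.y). The tx*ty threads cooperatively reduce the
// feat_in-long dot product W[idx, j, :] . X[batch_idx, :]: a two-stage
// cuda::pipeline overlaps the async copy of the next W/X tile into shared
// memory with the multiply-accumulate and warp-shuffle reduction of the
// current tile, and thread 0 of the block adds the scaled result into Y.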
// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t j = blockIdx.x;
constexpr size_t num_pipeline_stages = 2;
constexpr size_t tile_size = tx * ty * vec_size;
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
__shared__ float y_warpwise[ty];
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
auto pipe = cuda::make_pipeline();
// pipeline load W/X and compute WX;
pipe.producer_acquire();
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
pipe.producer_commit();
size_t copy_idx, compute_idx;
float y = 0.f;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
++tile_idx) {
copy_idx = tile_idx % num_pipeline_stages;
// pipeline stage: async copy W fragment
pipe.producer_acquire();
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) + tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
}
pipe.producer_commit();
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// pipeline stage: compute WX
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] = sum;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
}
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// final pipeline stage
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] =
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
? sum
: 0.f;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
// write Y;
if (block.thread_rank() == 0) {
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
}
}
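// bgmv_expand_kernel: used when feat_in < feat_out, so tx = feat_in / vec_size
// and a single vec_size-wide load per thread covers an entire input row. Each
// block handles ty*tz output features (blockIdx.x selects the tile of rows,
// blockIdx.y the batch element); a tiled_partition<tx> reduces the tx partial
// dot products, and lane 0 of every tile accumulates one element of Y.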
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t tile_idx = blockIdx.x;
// load X;
vec_t<in_T, vec_size> x_vec;
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
// load W;
vec_t<W_T, vec_size> w_vec;
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
block.thread_rank() * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += g.shfl_down(sum, offset);
}
sum = g.shfl(sum, 0);
if (threadIdx.x == 0) {
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
}
}
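// bgmv_kernel (host side) picks the launch geometry at compile time. A worked
// example, assuming feat_in = 16 and feat_out a multiple of 64: with
// vec_size = 8, tx = feat_in / vec_size = 2, the first branch applies
// (32 % 2 == 0), so ty = 32 / 2 = 16 and tz = 4 -- i.e. nthrs = (2, 16, 4),
// matching the comment above bgmv_expand_kernel -- and
// nblks = (feat_out / 64, batch_size).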
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
constexpr size_t vec_size = 8;
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in < feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
constexpr int ty = 32 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
constexpr int ty = 16 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else {
constexpr int ty = 8 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
} else {
static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0);
if constexpr (feat_in % (vec_size * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
vec_size * sizeof(W_T), tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
constexpr int tx = 16;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
}
}
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
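// FOR_BGMV_WIDE_NARROW itself lives in bgmv_config.h, which is not reproduced
// in this dump. As a rough sketch (the concrete pair list below is an
// assumption), it enumerates the supported (narrow LoRA rank, wide hidden
// size) pairs, e.g.
//
//   #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
//     f(in_T, out_T, W_T, 16, 4096)                   \
//     f(in_T, out_T, W_T, 16, 5120)                   \
//     ...
//
// so each generated .cu above expands INST_BGMV_TWOSIDE into explicit
// bgmv_kernel instantiations for both directions of every listed pair:
// expand (narrow -> wide) and shrink (wide -> narrow).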
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}

TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()

for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            if weight_dtype == "fp32":
                # FP32 weights are not supported.
                continue
            kernel_definition = TEMPLATE.format(
                input_dtype=DTYPE_MAP[input_dtype],
                output_dtype=DTYPE_MAP[output_dtype],
                weight_dtype=DTYPE_MAP[weight_dtype])
            filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
            with open(filename, "w") as f:
                f.write(kernel_definition)
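# Running this script writes 3 * 3 * 2 = 18 files named
# bgmv_{input}_{output}_{weight}.cu (e.g. bgmv_fp16_fp16_fp16.cu), each
# containing exactly the three-line TEMPLATE above; the generated sources at
# the top of this dump are a subset of that output.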
#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h>
#endif
#include <cuda_runtime.h>
#include <type_traits>
#define FLASHINFER_INLINE \
inline __attribute__((always_inline)) __device__ __host__
template <typename float_t, size_t vec_size>
struct vec_t {
FLASHINFER_INLINE float_t &operator[](size_t i);
FLASHINFER_INLINE const float_t &operator[](size_t i) const;
FLASHINFER_INLINE void fill(float_t val);
FLASHINFER_INLINE void load(const float_t *ptr);
FLASHINFER_INLINE void store(float_t *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src);
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr);
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const;
FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src);
};
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<src_float_t, vec_size> &src,
vec_t<tgt_float_t, vec_size> &dst) {
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
dst[i] = tgt_float_t(src[i]);
}
}
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr,
vec_t<tgt_float_t, vec_size> &dst) {
if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
dst.load(src_ptr);
} else {
vec_t<src_float_t, vec_size> tmp;
tmp.load(src_ptr);
dst.cast_from(tmp);
}
}
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_store_impl(const vec_t<src_float_t, vec_size> &src,
tgt_float_t *dst_ptr) {
if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
src.store(dst_ptr);
} else {
vec_t<tgt_float_t, vec_size> tmp;
tmp.cast_from(src);
tmp.store(dst_ptr);
}
}
#ifdef FLASHINFER_USE_FP8
/******************* vec_t<__nv_fp8_e4m3> *******************/
// __nv_fp8_e4m3 x 1
template <>
struct vec_t<__nv_fp8_e4m3, 1> {
__nv_fp8_e4m3 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) {
data = val;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) {
data = *ptr;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(
__nv_fp8_e4m3 *ptr) const {
*ptr = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*dst = *src;
}
// __nv_fp8_e4m3 x 2
template <>
struct vec_t<__nv_fp8_e4m3, 2> {
__nv_fp8x2_e4m3 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
data.__x =
(__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) {
data = *((__nv_fp8x2_e4m3 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(
__nv_fp8_e4m3 *ptr) const {
*((__nv_fp8x2_e4m3 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src);
}
// __nv_fp8_e4m3 x 4
template <>
struct vec_t<__nv_fp8_e4m3, 4> {
__nv_fp8x4_e4m3 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) {
data = *((__nv_fp8x4_e4m3 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(
__nv_fp8_e4m3 *ptr) const {
*((__nv_fp8x4_e4m3 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src);
}
// __nv_fp8_e4m3 x 8
template <>
struct vec_t<__nv_fp8_e4m3, 8> {
uint2 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 8> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(
__nv_fp8_e4m3 *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*((uint2 *)dst) = *((uint2 *)src);  // copy all 8 fp8 values, matching load()/store()
}
// __nv_fp8_e4m3 x 16 or more
template <size_t vec_size>
struct vec_t<__nv_fp8_e4m3, vec_size> {
uint4 data[vec_size / 16];
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)data)[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)data)[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
}
}
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
/******************* vec_t<__nv_fp8_e5m2> *******************/
// __nv_fp8_e5m2 x 1
template <>
struct vec_t<__nv_fp8_e5m2, 1> {
__nv_fp8_e5m2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) {
data = val;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) {
data = *ptr;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(
__nv_fp8_e5m2 *ptr) const {
*ptr = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*dst = *src;
}
// __nv_fp8_e5m2 x 2
template <>
struct vec_t<__nv_fp8_e5m2, 2> {
__nv_fp8x2_e5m2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
data.__x =
(__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) {
data = *((__nv_fp8x2_e5m2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(
__nv_fp8_e5m2 *ptr) const {
*((__nv_fp8x2_e5m2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src);
}
// __nv_fp8_e5m2 x 4
template <>
struct vec_t<__nv_fp8_e5m2, 4> {
__nv_fp8x4_e5m2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) {
data = *((__nv_fp8x4_e5m2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(
__nv_fp8_e5m2 *ptr) const {
*((__nv_fp8x4_e5m2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src);
}
// __nv_fp8_e5m2 x 8
template <>
struct vec_t<__nv_fp8_e5m2, 8> {
uint2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 8> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(
__nv_fp8_e5m2 *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*((uint2 *)dst) = *((uint2 *)src);  // copy all 8 fp8 values, matching load()/store()
}
// __nv_fp8_e5m2 x 16 or more
template <size_t vec_size>
struct vec_t<__nv_fp8_e5m2, vec_size> {
uint4 data[vec_size / 16];
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)data)[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)data)[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
}
}
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
#endif
/******************* vec_t<half> *******************/
// half x 1
template <>
struct vec_t<half, 1> {
half data;
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)(&data))[i];
}
FLASHINFER_INLINE void fill(half val);
FLASHINFER_INLINE void load(const half *ptr);
FLASHINFER_INLINE void store(half *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};
FLASHINFER_INLINE void vec_t<half, 1>::fill(half val) { data = val; }
FLASHINFER_INLINE void vec_t<half, 1>::load(const half *ptr) { data = *ptr; }
FLASHINFER_INLINE void vec_t<half, 1>::store(half *ptr) const { *ptr = data; }
FLASHINFER_INLINE void vec_t<half, 1>::memcpy(half *dst, const half *src) {
*dst = *src;
}
// half x 2
template <>
struct vec_t<half, 2> {
half2 data;
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)(&data))[i];
}
FLASHINFER_INLINE void fill(half val);
FLASHINFER_INLINE void load(const half *ptr);
FLASHINFER_INLINE void store(half *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};
FLASHINFER_INLINE void vec_t<half, 2>::fill(half val) {
data = make_half2(val, val);
}
FLASHINFER_INLINE void vec_t<half, 2>::load(const half *ptr) {
data = *((half2 *)ptr);
}
FLASHINFER_INLINE void vec_t<half, 2>::store(half *ptr) const {
*((half2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<half, 2>::memcpy(half *dst, const half *src) {
*((half2 *)dst) = *((half2 *)src);
}
// half x 4
template <>
struct vec_t<half, 4> {
uint2 data;
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)(&data))[i];
}
FLASHINFER_INLINE void fill(half val);
FLASHINFER_INLINE void load(const half *ptr);
FLASHINFER_INLINE void store(half *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};
FLASHINFER_INLINE void vec_t<half, 4>::fill(half val) {
*(half2 *)(&data.x) = make_half2(val, val);
*(half2 *)(&data.y) = make_half2(val, val);
}
FLASHINFER_INLINE void vec_t<half, 4>::load(const half *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<half, 4>::store(half *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<half, 4>::memcpy(half *dst, const half *src) {
*((uint2 *)dst) = *((uint2 *)src);
}
// half x 8 or more
template <size_t vec_size>
struct vec_t<half, vec_size> {
uint4 data[vec_size / 8];
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)data)[i];
}
FLASHINFER_INLINE void fill(half val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
*(half2 *)(&(data[i].x)) = make_half2(val, val);
*(half2 *)(&(data[i].y)) = make_half2(val, val);
*(half2 *)(&(data[i].z)) = make_half2(val, val);
*(half2 *)(&(data[i].w)) = make_half2(val, val);
}
}
FLASHINFER_INLINE void load(const half *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(half *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
/******************* vec_t<nv_bfloat16> *******************/
// nv_bfloat16 x 1
template <>
struct vec_t<nv_bfloat16, 1> {
nv_bfloat16 data;
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val);
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src);
};
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) {
data = val;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16 *ptr) {
data = *ptr;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16 *ptr) const {
*ptr = data;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
*dst = *src;
}
// nv_bfloat16 x 2
template <>
struct vec_t<nv_bfloat16, 2> {
nv_bfloat162 data;
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val);
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src);
};
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) {
data = make_bfloat162(val, val);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16 *ptr) {
data = *((nv_bfloat162 *)ptr);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16 *ptr) const {
*((nv_bfloat162 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
*((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src);
}
// nv_bfloat16 x 4
template <>
struct vec_t<nv_bfloat16, 4> {
uint2 data;
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val);
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src);
};
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) {
*(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16 *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16 *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
*((uint2 *)dst) = *((uint2 *)src);
}
// nv_bfloat16 x 8 or more
template <size_t vec_size>
struct vec_t<nv_bfloat16, vec_size> {
uint4 data[vec_size / 8];
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)data)[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)data)[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
*(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val);
}
}
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
/******************* vec_t<float> *******************/
// float x 1
template <>
struct vec_t<float, 1> {
float data;
FLASHINFER_INLINE float &operator[](size_t i) {
return ((float *)(&data))[i];
}
FLASHINFER_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
}
FLASHINFER_INLINE void fill(float val);
FLASHINFER_INLINE void load(const float *ptr);
FLASHINFER_INLINE void store(float *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
};
FLASHINFER_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
FLASHINFER_INLINE void vec_t<float, 1>::load(const float *ptr) { data = *ptr; }
FLASHINFER_INLINE void vec_t<float, 1>::store(float *ptr) const { *ptr = data; }
FLASHINFER_INLINE void vec_t<float, 1>::memcpy(float *dst, const float *src) {
*dst = *src;
}
// float x 2
template <>
struct vec_t<float, 2> {
float2 data;
FLASHINFER_INLINE float &operator[](size_t i) {
return ((float *)(&data))[i];
}
FLASHINFER_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
}
FLASHINFER_INLINE void fill(float val);
FLASHINFER_INLINE void load(const float *ptr);
FLASHINFER_INLINE void store(float *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
};
FLASHINFER_INLINE void vec_t<float, 2>::fill(float val) {
data = make_float2(val, val);
}
FLASHINFER_INLINE void vec_t<float, 2>::load(const float *ptr) {
data = *((float2 *)ptr);
}
FLASHINFER_INLINE void vec_t<float, 2>::store(float *ptr) const {
*((float2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<float, 2>::memcpy(float *dst, const float *src) {
*((float2 *)dst) = *((float2 *)src);
}
// float x 4 or more
template <size_t vec_size>
struct vec_t<float, vec_size> {
float4 data[vec_size / 4];
FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; }
FLASHINFER_INLINE const float &operator[](size_t i) const {
return ((const float *)(data))[i];
}
FLASHINFER_INLINE void fill(float val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = make_float4(val, val, val, val);
}
}
FLASHINFER_INLINE void load(const float *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = ((float4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(float *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(float *dst, const float *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)dst)[i] = ((float4 *)src)[i];
}
}
};
/******************* vec_t type cast *******************/
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<half, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = half(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<nv_bfloat16, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((float2 *)(&dst.data))[i] =
__bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<nv_bfloat16, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = nv_bfloat16(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((nv_bfloat162 *)(&dst.data))[i] =
__float22bfloat162_rn(((float2 *)(&src.data))[i]);
}
}
}
#ifdef FLASHINFER_USE_FP8
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else if constexpr (vec_size == 2) {
*(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src,
vec_t<half, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<__nv_fp8_e4m3, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e4m3(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((__nv_fp8x4_e4m3 *)(&dst.data))[i] =
__nv_fp8x4_e4m3(((float4 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size> &src,
vec_t<__nv_fp8_e4m3, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e4m3(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
// NOTE(Zihao): need to double check if we properly handle flo and fhi
((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3(
((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else if constexpr (vec_size == 2) {
*(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src,
vec_t<half, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<__nv_fp8_e5m2, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e5m2(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((__nv_fp8x4_e5m2 *)(&dst.data))[i] =
__nv_fp8x4_e5m2(((float4 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size> &src,
vec_t<__nv_fp8_e5m2, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e5m2(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
// NOTE(Zihao): need to double check if we properly handle flo and fhi
((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2(
((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]);
}
}
}
#endif // FLASHINFER_USE_FP8
#endif // VEC_DTYPES_CUH_
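// Usage sketch for vec_t (illustrative only; scale_row below is not part of
// this header): vec_t gives a uniform vectorized load / dtype-convert / store
// path across dtypes, e.g.
//
//   template <size_t vec_size>
//   __device__ void scale_row(const half *in, float *out, float s) {
//     vec_t<float, vec_size> v;
//     v.cast_load(in + threadIdx.x * vec_size);   // fp16 -> fp32 on load
//     for (size_t i = 0; i < vec_size; ++i) v[i] *= s;
//     v.store(out + threadIdx.x * vec_size);
//   }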
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <cstdint>
#include "bgmv/bgmv_config.h"
namespace {
//====== utils ======
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
const char *a_name, const char *b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
".size(", i, ")");
}
}
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
return (uint32_t(a) << 16) | uint32_t(b);
}
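// pack_u16 folds (in_features, out_features) into a single 32-bit switch key;
// for example pack_u16(16, 4096) == (16u << 16) | 4096u == 0x00101000, which
// is the value the CASE_ONESIDE labels in launch_bgmv_kernel match against.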
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_DIM(d, x) \
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) \
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
//====== bgmv ======
template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices,
uint16_t in_features, uint16_t out_features,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
switch (pack_u16(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u16(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}
return true;
}
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t h_in = x.size(1);
int64_t h_out = y.size(1);
int64_t num_layers = w.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
bool ok = false;
if (h_in < 65536 && h_out < 65536) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
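// dispatch_bgmv_low_level differs from dispatch_bgmv in that the caller passes
// h_in, h_out and y_offset explicitly while full_y_size is taken from
// y.size(1), so the kernel can accumulate into the column slice
// [y_offset, y_offset + h_out) of a wider output tensor rather than into the
// whole row.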
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out,
int64_t y_offset) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t num_layers = w.size(1);
int64_t full_y_size = y.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
bool ok = false;
if (h_in < 65536 && h_out < 65536) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
} // namespace
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
"dispatch_bgmv_low_level");
}
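The TODO in dispatch_bgmv_low_level above asks whether the massive nested switch can be removed. One possible direction, shown here as a minimal self-contained sketch rather than the extension's actual code, is to pack the three scalar types into a single integer key and dispatch through one switch; DType, dtype_key, and dispatch are hypothetical stand-ins for at::ScalarType and the launch_bgmv_kernel calls.
#include <cstdint>
#include <iostream>
// Hypothetical stand-in for at::ScalarType, covering only the types used above.
enum class DType : uint8_t { F16 = 0, BF16 = 1, F32 = 2 };
// Pack the (x, y, w) dtypes into one integer so a single switch can dispatch.
constexpr uint32_t dtype_key(DType x, DType y, DType w) {
  return (uint32_t(x) << 16) | (uint32_t(y) << 8) | uint32_t(w);
}
bool dispatch(DType x, DType y, DType w) {
  switch (dtype_key(x, y, w)) {
    case dtype_key(DType::F16, DType::F16, DType::F16):
      // would call launch_bgmv_kernel<nv_half, nv_half, nv_half>(...)
      return true;
    case dtype_key(DType::F16, DType::F16, DType::BF16):
      // would call launch_bgmv_kernel<nv_half, nv_half, nv_bfloat16>(...)
      return true;
    // ... one case per remaining supported combination ...
    default:
      return false;  // no suitable kernel
  }
}
int main() {
  std::cout << dispatch(DType::F16, DType::F16, DType::F16) << "\n";   // 1
  std::cout << dispatch(DType::F32, DType::BF16, DType::F16) << "\n";  // 0 in this sketch
}
Because dtype_key is constexpr, each case label remains a compile-time constant, and unsupported combinations fall through to the same "no suitable kernel" path that the nested switches report via TORCH_CHECK.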
...@@ -51,10 +51,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifndef USE_ROCM
// Quantization ops
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
...@@ -74,6 +79,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"gather_cached_kv",
&gather_cached_kv,
"Gather key and value from the cache into contiguous QKV tensors");
cache_ops.def(
"convert_fp8_e5m2",
&convert_fp8_e5m2,
"Convert the key and value cache to fp8_e5m2 data type");
// Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
...@@ -81,4 +90,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");
#ifndef USE_ROCM
// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
custom_ar.def("dispose", &dispose, "dispose");
custom_ar.def("meta_size", &meta_size, "meta_size");
custom_ar.def("register_buffer", &register_buffer, "register_buffer");
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
"get_graph_buffer_ipc_meta");
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#endif
}
...@@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
#endif
}
// Dequantize AWQ 4-bit weights: each thread expands one packed int32 of B
// into eight halves, applying (q - zero) * scale, and writes them to C.
__global__ void __launch_bounds__(64) dequantize_weights(
int* __restrict__ B,
half* __restrict__ scaling_factors,
int* __restrict__ zeros,
half* __restrict__ C,
int G
)
{
int j_factors1 = 4;
int row_stride2 = 4;
int split_k_iters = 1;
static constexpr uint32_t ZERO = 0x0;
half B_shared[32 * (128 + 8)];
half* B_shared_ptr2 = B_shared;
half B_shared_warp[32];
int OC = 512;
int N = blockDim.x * gridDim.x; // 2
int col = (blockIdx.x * blockDim.x + threadIdx.x);
int row = blockIdx.y * blockDim.y + threadIdx.y;
int index1 = 8 * col + 8 * row * N;
half* C_ptr2 = C + index1;
int index2 = col + row * N;
int* B_ptr2 = B + index2;
int index3 = col + (int)(row / G) * N;
int* zeros_ptr2 = zeros + index3;
int index4 = 8 * col + (int)(row / G) * N * 8;
half* scaling_factors_ptr2 = scaling_factors + index4;
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
int j=0;
uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
*(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
for (int i=0; i<8; ++i) {
*(C_ptr2 + i) = B_shared[i];
}
}
} // namespace awq
} // namespace vllm
torch::Tensor awq_dequantize(
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters,
int thx,
int thy)
{
int in_c = _kernel.size(0);
int qout_c = _kernel.size(1);
int out_c = qout_c * 8;
int G = in_c / _scaling_factors.size(0);
int x_thread = thx;
int y_thread = thy;
int x_blocks = 1;
int y_blocks = 1;
if (thx==0) {
x_thread = qout_c;
}
if (thy==0) {
y_thread = in_c;
}
if (thx==0 && thy==0) {
x_thread = 8;
y_thread = 8;
x_blocks = (int)(qout_c / 8);
y_blocks = (int)(in_c / 8);
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_thread, y_thread);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
kernel, scaling_factors, zeros, de_kernel, G);
return _de_kernel;
}
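In awq_dequantize above, out_c = qout_c * 8 because every int32 of the quantized kernel packs eight 4-bit weights; on the device dequantize_s4_to_fp16x2 expands them and the sub.f16x2/fma.rn.f16x2 instructions apply (q - zero) * scale. A rough host-side sketch of the unpacking step, assuming plain low-to-high nibble order (the real AWQ layout is interleaved and handled inside the PTX helper):
#include <cstdint>
#include <cstdio>
// Unpack eight 4-bit unsigned values from one 32-bit word, low nibble first.
// (Illustration only; dequantize_s4_to_fp16x2 emits them in an interleaved
// order that matches the fp16x2 math in the kernel above.)
void unpack_s4x8(uint32_t packed, int out[8]) {
  for (int i = 0; i < 8; ++i) {
    out[i] = (packed >> (4 * i)) & 0xF;
  }
}
int main() {
  int vals[8];
  unpack_s4x8(0x76543210u, vals);
  for (int i = 0; i < 8; ++i) {
    std::printf("%d ", vals[i]);  // prints: 0 1 2 3 4 5 6 7
  }
  std::printf("\n");
}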
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
...
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include "../../attention/attention_dtypes.h"
#include "../../attention/dtype_float32.cuh"
#include "../../attention/dtype_float16.cuh"
#include "../../attention/dtype_bfloat16.cuh"
#pragma once
namespace vllm {
#ifdef ENABLE_FP8_E5M2
namespace fp8_e5m2_unscaled {
template<typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
return x;
}
// fp8 -> half
template<>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
{
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
return res.x;
}
// fp8x2 -> half2
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
{
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
tmp.u16[0] = res.x;
tmp.u16[1] = res.y;
return tmp.u32;
}
// fp8x4 -> half2x2
template<>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
{
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
return tmp.u32x2;
}
// fp8x8 -> half2x4
template<>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
{
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
return tmp.u64x2;
}
// fp8 -> __nv_bfloat16
template<>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
{
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
// half -> float -> bf16
float tmp = half_to_float(res.x);
return __float2bfloat16(tmp);
}
// fp8x2 -> __nv_bfloat162
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
{
__nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res;
}
// fp8x4 -> bf16_4_t
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
{
bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> bf16_8_t
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
{
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template<>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
{
// fp8 -> half
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
// half -> float
return half_to_float(tmp);
}
// fp8x2 -> float2
template<>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
{
// fp8x2 -> half2
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
// half2 -> float2
return half2_to_float2(tmp);
}
// fp8x4 -> float4
template<>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> float8
template<>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
__half_raw tmp;
tmp.x = a;
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
}
// bf16 -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
#endif
}
// float -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
}
// fp8x4 -> float4
template<>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
template<>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
uint2 b;
float2 val;
val.x = a.x.x;
val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x;
val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val);
return b;
}
template<>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
float4 b;
b.x = a.x.x;
b.y = a.x.y;
b.z = a.y.x;
b.w = a.y.y;
return b;
}
template<>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w);
return b;
}
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
__nv_bfloat162 b;
from_float(b, a);
return b;
}
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
bf16_4_t b;
from_float(b, a);
return b;
}
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
bf16_8_t b;
from_float(b, a);
return b;
}
} // namespace fp8_e5m2_unscaled
#endif // ENABLE_FP8_E5M2
} // namespace vllm
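The specializations above delegate to the __nv_cvt_* intrinsics. E5M2 is a truncated IEEE half (1 sign, 5 exponent, 2 mantissa bits, bias 15), which is why the fp8-to-half direction is essentially a widening of the bit pattern. A small host-side sketch, offered as an illustration rather than code from this header, that decodes an E5M2 byte directly to float:
#include <cmath>
#include <cstdint>
#include <cstdio>
// Decode one fp8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits, bias 15)
// to float. Mirrors what __nv_cvt_fp8_to_halfraw + half_to_float do on the
// device, without treating NaN payloads specially.
float fp8_e5m2_to_float(uint8_t a) {
  int sign = (a >> 7) & 0x1;
  int exp = (a >> 2) & 0x1F;
  int man = a & 0x3;
  float mag;
  if (exp == 0) {
    mag = std::ldexp(static_cast<float>(man), -2 - 14);  // subnormal
  } else if (exp == 31) {
    mag = (man == 0) ? INFINITY : NAN;                    // inf / NaN
  } else {
    mag = std::ldexp(1.0f + man / 4.0f, exp - 15);        // normal
  }
  return sign ? -mag : mag;
}
int main() {
  std::printf("%g\n", fp8_e5m2_to_float(0x3C));  // exp 15, mantissa 0 -> 1.0
  std::printf("%g\n", fp8_e5m2_to_float(0x40));  // exp 16, mantissa 0 -> 2.0
}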
...@@ -9,11 +9,15 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import os
import sys
from sphinx.ext import autodoc
import logging
sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
logger = logging.getLogger(__name__)
# -- Project information -----------------------------------------------------
...@@ -21,7 +25,6 @@ project = 'vLLM'
copyright = '2023, vLLM Team'
author = 'the vLLM Team'
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
...@@ -32,6 +35,8 @@ extensions = [
"sphinx.ext.viewcode",
"sphinx.ext.intersphinx",
"sphinx_copybutton",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
]
# Add any paths that contain templates here, relative to this directory.
...@@ -55,7 +60,6 @@ html_title = project
html_theme = 'sphinx_book_theme'
html_logo = 'assets/logos/vllm-logo-text-light.png'
html_theme_options = {
'logo_only': True,
'path_to_docs': 'docs/source',
'repository_url': 'https://github.com/vllm-project/vllm',
'use_repository_button': True,
...@@ -64,4 +68,29 @@ html_theme_options = {
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# html_static_path = ['_static']
# Mock out external dependencies here.
autodoc_mock_imports = [
"torch", "transformers", "psutil", "aioprometheus", "sentencepiece",
"vllm.cuda_utils", "vllm._C"
]
for mock_target in autodoc_mock_imports:
if mock_target in sys.modules:
logger.info(
f"Potentially problematic mock target ({mock_target}) found; "
"autodoc_mock_imports cannot mock modules that have already "
"been loaded into sys.modules when the sphinx build starts.")
class MockedClassDocumenter(autodoc.ClassDocumenter):
"""Remove note about base class when a class is derived from object."""
def add_line(self, line: str, source: str, *lineno: int) -> None:
if line == " Bases: :py:class:`object`":
return
super().add_line(line, source, *lineno)
autodoc.ClassDocumenter = MockedClassDocumenter
AsyncLLMEngine
=================================
.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine
:members: generate, abort
:show-inheritance:
vLLM Engine
=================================
.. automodule:: vllm.engine
.. currentmodule:: vllm.engine
.. toctree::
:maxdepth: 2
:caption: Engines
llm_engine
async_llm_engine
LLMEngine
=================================
.. autoclass:: vllm.engine.llm_engine.LLMEngine
:members: add_request, abort_request, step, _init_cache
:show-inheritance:
\ No newline at end of file
...@@ -11,10 +11,10 @@ Requirements
------------
* OS: Linux
* Python: 3.8 -- 3.11
* GPU: MI200s (gfx90a), MI300 (gfx942)
* Pytorch 2.0.1/2.1.1/2.2
* ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)
Installation options:
...@@ -27,6 +27,8 @@ Installation options:
(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
---------------------------------------------------------------------------
This option is for ROCm 5.7 only:
.. code-block:: console
$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
...@@ -50,6 +52,9 @@ Option 2: Build from source
You can build and install vLLM from source:
The instructions below are for ROCm 5.7 only.
At the time of this documentation update, a PyTorch wheel for ROCm 6.0 is not yet available on the PyTorch website.
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
...@@ -95,6 +100,23 @@ You can build and install vLLM from source:
Build a docker image from `Dockerfile.rocm`, and launch a docker container.
The `Dockerfile.rocm` is designed to support ROCm 5.7 as well as ROCm 6.0 and later versions. It provides flexibility to customize the docker image build with the following arguments:
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
For example, to build a docker image for vLLM on ROCm 5.7, you can run:
.. code-block:: console
$ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
-f Dockerfile.rocm -t vllm-rocm .
To build vLLM on ROCm 6.0, you can use the default build arguments:
.. code-block:: console
$ docker build -f Dockerfile.rocm -t vllm-rocm .
...@@ -142,3 +164,8 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
$ cd vllm
$ pip install -U -r requirements-rocm.txt
$ python setup.py install # This may take 5-10 minutes.
.. note::
- You may need to turn on the ``--enforce-eager`` flag if you experience a process hang when running the `benchmark_throughput.py` script to test your installation.
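For example (the model name and sequence lengths below are placeholders; adjust them to your setup):
.. code-block:: console
$ python benchmarks/benchmark_throughput.py --model <your-model> --input-len 512 --output-len 128 --enforce-eager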