Commit e00b0a19 authored by zhuwenwen

merge v0.3.3

parents ead94d93 3f1166ab
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
#pragma once
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale);
// clang-format off
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
f(in_T, out_T, W_T, narrow, 9216) \
f(in_T, out_T, W_T, narrow, 10240) \
f(in_T, out_T, W_T, narrow, 11008) \
f(in_T, out_T, W_T, narrow, 12288) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32512) \
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
// clang-format on
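// Illustrative expansion (a reader's sketch, not emitted by the build): with
// f = INST_BGMV_TWOSIDE and (in_T, out_T, W_T) = (nv_half, nv_half, nv_half),
// the pair narrow = 16, wide = 4096 produces two explicit instantiations of
// bgmv_kernel, one shrink-shaped and one expand-shaped:
//
//   template void bgmv_kernel<16, 4096>(
//       nv_half *__restrict__ Y, const nv_half *__restrict__ X,
//       const nv_half *__restrict__ W, const int64_t *__restrict__ indicies,
//       int64_t y_offset, int64_t full_y_size, int64_t batch_size,
//       int64_t num_layers, int64_t layer_idx, float scale);
//   template void bgmv_kernel<4096, 16>(/* same signature */);
//
// FOR_BGMV_WIDE_NARROW repeats this for narrow in {8, 16, 32, 64} and every
// wide size listed in FOR_BGMV_WIDE above.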
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
#include "vec_dtypes.cuh"
namespace cg = cooperative_groups;
// nthrs = (32, 4)
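// Work decomposition (descriptive note): each thread block handles one output
// element j = blockIdx.x of one batch entry blockIdx.y. The feat_in dimension
// is split into tiles of tx * ty * vec_size elements that are double-buffered
// through shared memory with cuda::pipeline (num_pipeline_stages = 2); each
// tile's partial dot product is reduced with __shfl_down_sync within a warp
// and then across the ty warps via the y_warpwise buffer.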
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t j = blockIdx.x;
constexpr size_t num_pipeline_stages = 2;
constexpr size_t tile_size = tx * ty * vec_size;
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
__shared__ float y_warpwise[ty];
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
auto pipe = cuda::make_pipeline();
// pipeline load W/X and compute WX;
pipe.producer_acquire();
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
pipe.producer_commit();
size_t copy_idx, compute_idx;
float y = 0.f;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
++tile_idx) {
copy_idx = tile_idx % num_pipeline_stages;
// pipeline stage: async copy W fragment
pipe.producer_acquire();
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) + tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
}
pipe.producer_commit();
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// pipeline stage: compute WX
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] = sum;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
}
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// final pipeline stage
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] =
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
? sum
: 0.f;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
// write Y;
if (block.thread_rank() == 0) {
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
}
}
// nthrs = (2, 16, 4)
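// Work decomposition (descriptive note): here feat_in is small (the launcher
// picks tx = feat_in / vec_size, so tx * vec_size == feat_in) and each group
// of tx threads computes one full dot product. A block covers ty * tz
// consecutive output features of one batch entry; the per-group reduction uses
// a tiled_partition<tx> with shfl_down, and only lane 0 of each group
// accumulates into Y.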
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t tile_idx = blockIdx.x;
// load X;
vec_t<in_T, vec_size> x_vec;
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
// load W;
vec_t<W_T, vec_size> w_vec;
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
block.thread_rank() * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += g.shfl_down(sum, offset);
}
sum = g.shfl(sum, 0);
if (threadIdx.x == 0) {
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
}
}
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
constexpr size_t vec_size = 8;
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in < feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
constexpr int ty = 32 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
constexpr int ty = 16 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else {
constexpr int ty = 8 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
} else {
static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0);
if constexpr (feat_in % (vec_size * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
vec_size * sizeof(W_T), tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
constexpr int tx = 16;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
}
}
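// Launch-shape note (descriptive): in the expand case (feat_in < feat_out),
// tx = feat_in / vec_size lanes form one reduction group, ty is chosen so that
// tx * ty equals 32, 16, or 8, and a block covers ty * tz output features
// (tz = 4). In the shrink case the block is (tx, ty) = (32, 4) or (16, 4), and
// vec_size is halved when feat_in is not a multiple of vec_size * 32.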
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}
TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()
for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            if weight_dtype == "fp32":
                # FP32 weights are not supported.
                continue
            kernel_definition = TEMPLATE.format(
                input_dtype=DTYPE_MAP[input_dtype],
                output_dtype=DTYPE_MAP[output_dtype],
                weight_dtype=DTYPE_MAP[weight_dtype])
            filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
            with open(filename, "w") as f:
                f.write(kernel_definition)
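# Illustrative output (a sketch of one iteration, not executed here): the
# combination input_dtype="fp16", output_dtype="fp32", weight_dtype="bf16"
# writes a file named bgmv_fp16_fp32_bf16.cu containing:
#
#   #include "bgmv_config.h"
#   #include "bgmv_impl.cuh"
#   FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)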
#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h>
#endif
#include <cuda_runtime.h>
#include <type_traits>
#define FLASHINFER_INLINE \
inline __attribute__((always_inline)) __device__ __host__
template <typename float_t, size_t vec_size>
struct vec_t {
FLASHINFER_INLINE float_t &operator[](size_t i);
FLASHINFER_INLINE const float_t &operator[](size_t i) const;
FLASHINFER_INLINE void fill(float_t val);
FLASHINFER_INLINE void load(const float_t *ptr);
FLASHINFER_INLINE void store(float_t *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src);
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr);
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const;
FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src);
};
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<src_float_t, vec_size> &src,
vec_t<tgt_float_t, vec_size> &dst) {
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
dst[i] = tgt_float_t(src[i]);
}
}
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr,
vec_t<tgt_float_t, vec_size> &dst) {
if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
dst.load(src_ptr);
} else {
vec_t<src_float_t, vec_size> tmp;
tmp.load(src_ptr);
dst.cast_from(tmp);
}
}
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_store_impl(const vec_t<src_float_t, vec_size> &src,
tgt_float_t *dst_ptr) {
if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
src.store(dst_ptr);
} else {
vec_t<tgt_float_t, vec_size> tmp;
tmp.cast_from(src);
tmp.store(dst_ptr);
}
}
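// Usage sketch (illustrative only, not part of the header): a vectorized
// load-convert-store round trip through the vec_t interface.
//
//   vec_t<float, 8> v;
//   v.cast_load(half_ptr);   // load 8 halves, widen to float
//   v.cast_store(bf16_ptr);  // narrow to bf16, store 8 values
//
// where half_ptr / bf16_ptr are suitably aligned device pointers supplied by
// the caller.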
#ifdef FLASHINFER_USE_FP8
/******************* vec_t<__nv_fp8_e4m3> *******************/
// __nv_fp8_e4m3 x 1
template <>
struct vec_t<__nv_fp8_e4m3, 1> {
__nv_fp8_e4m3 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) {
data = val;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) {
data = *ptr;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(
__nv_fp8_e4m3 *ptr) const {
*ptr = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*dst = *src;
}
// __nv_fp8_e4m3 x 2
template <>
struct vec_t<__nv_fp8_e4m3, 2> {
__nv_fp8x2_e4m3 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
data.__x =
(__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) {
data = *((__nv_fp8x2_e4m3 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(
__nv_fp8_e4m3 *ptr) const {
*((__nv_fp8x2_e4m3 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src);
}
// __nv_fp8_e4m3 x 4
template <>
struct vec_t<__nv_fp8_e4m3, 4> {
__nv_fp8x4_e4m3 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) {
data = *((__nv_fp8x4_e4m3 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(
__nv_fp8_e4m3 *ptr) const {
*((__nv_fp8x4_e4m3 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src);
}
// __nv_fp8_e4m3 x 8
template <>
struct vec_t<__nv_fp8_e4m3, 8> {
uint2 data;
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 8> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(
__nv_fp8_e4m3 *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(
__nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) {
*((uint2 *)dst) = *((uint2 *)src);
}
// __nv_fp8_e4m3 x 16 or more
template <size_t vec_size>
struct vec_t<__nv_fp8_e4m3, vec_size> {
uint4 data[vec_size / 16];
FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) {
return ((__nv_fp8_e4m3 *)data)[i];
}
FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const {
return ((const __nv_fp8_e4m3 *)data)[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
}
}
FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst,
const __nv_fp8_e4m3 *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
/******************* vec_t<__nv_fp8_e5m2> *******************/
// __nv_fp8_e5m2 x 1
template <>
struct vec_t<__nv_fp8_e5m2, 1> {
__nv_fp8_e5m2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) {
data = val;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) {
data = *ptr;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(
__nv_fp8_e5m2 *ptr) const {
*ptr = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*dst = *src;
}
// __nv_fp8_e5m2 x 2
template <>
struct vec_t<__nv_fp8_e5m2, 2> {
__nv_fp8x2_e5m2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
data.__x =
(__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) {
data = *((__nv_fp8x2_e5m2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(
__nv_fp8_e5m2 *ptr) const {
*((__nv_fp8x2_e5m2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src);
}
// __nv_fp8_e5m2 x 4
template <>
struct vec_t<__nv_fp8_e5m2, 4> {
__nv_fp8x4_e5m2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) {
data = *((__nv_fp8x4_e5m2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(
__nv_fp8_e5m2 *ptr) const {
*((__nv_fp8x4_e5m2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src);
}
// __nv_fp8_e5m2 x 8
template <>
struct vec_t<__nv_fp8_e5m2, 8> {
uint2 data;
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)(&data))[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 8> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src);
};
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) |
__nv_fp8x4_storage_t(val.__x);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(
__nv_fp8_e5m2 *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(
__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) {
*((uint2 *)dst) = *((uint2 *)src);
}
// __nv_fp8_e5m2 x 16 or more
template <size_t vec_size>
struct vec_t<__nv_fp8_e5m2, vec_size> {
uint4 data[vec_size / 16];
FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) {
return ((__nv_fp8_e5m2 *)data)[i];
}
FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const {
return ((const __nv_fp8_e5m2 *)data)[i];
}
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x =
(__nv_fp8x4_storage_t(val.__x) << 24) |
(__nv_fp8x4_storage_t(val.__x) << 16) |
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
}
}
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
const __nv_fp8_e5m2 *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 16; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
#endif
/******************* vec_t<half> *******************/
// half x 1
template <>
struct vec_t<half, 1> {
half data;
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)(&data))[i];
}
FLASHINFER_INLINE void fill(half val);
FLASHINFER_INLINE void load(const half *ptr);
FLASHINFER_INLINE void store(half *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};
FLASHINFER_INLINE void vec_t<half, 1>::fill(half val) { data = val; }
FLASHINFER_INLINE void vec_t<half, 1>::load(const half *ptr) { data = *ptr; }
FLASHINFER_INLINE void vec_t<half, 1>::store(half *ptr) const { *ptr = data; }
FLASHINFER_INLINE void vec_t<half, 1>::memcpy(half *dst, const half *src) {
*dst = *src;
}
// half x 2
template <>
struct vec_t<half, 2> {
half2 data;
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)(&data))[i];
}
FLASHINFER_INLINE void fill(half val);
FLASHINFER_INLINE void load(const half *ptr);
FLASHINFER_INLINE void store(half *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};
FLASHINFER_INLINE void vec_t<half, 2>::fill(half val) {
data = make_half2(val, val);
}
FLASHINFER_INLINE void vec_t<half, 2>::load(const half *ptr) {
data = *((half2 *)ptr);
}
FLASHINFER_INLINE void vec_t<half, 2>::store(half *ptr) const {
*((half2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<half, 2>::memcpy(half *dst, const half *src) {
*((half2 *)dst) = *((half2 *)src);
}
// half x 4
template <>
struct vec_t<half, 4> {
uint2 data;
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)(&data))[i];
}
FLASHINFER_INLINE void fill(half val);
FLASHINFER_INLINE void load(const half *ptr);
FLASHINFER_INLINE void store(half *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};
FLASHINFER_INLINE void vec_t<half, 4>::fill(half val) {
*(half2 *)(&data.x) = make_half2(val, val);
*(half2 *)(&data.y) = make_half2(val, val);
}
FLASHINFER_INLINE void vec_t<half, 4>::load(const half *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<half, 4>::store(half *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<half, 4>::memcpy(half *dst, const half *src) {
*((uint2 *)dst) = *((uint2 *)src);
}
// half x 8 or more
template <size_t vec_size>
struct vec_t<half, vec_size> {
uint4 data[vec_size / 8];
FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; }
FLASHINFER_INLINE const half &operator[](size_t i) const {
return ((const half *)data)[i];
}
FLASHINFER_INLINE void fill(half val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
*(half2 *)(&(data[i].x)) = make_half2(val, val);
*(half2 *)(&(data[i].y)) = make_half2(val, val);
*(half2 *)(&(data[i].z)) = make_half2(val, val);
*(half2 *)(&(data[i].w)) = make_half2(val, val);
}
}
FLASHINFER_INLINE void load(const half *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(half *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(half *dst, const half *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
/******************* vec_t<nv_bfloat16> *******************/
// nv_bfloat16 x 1
template <>
struct vec_t<nv_bfloat16, 1> {
nv_bfloat16 data;
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val);
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src);
};
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) {
data = val;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16 *ptr) {
data = *ptr;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16 *ptr) const {
*ptr = data;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
*dst = *src;
}
// nv_bfloat16 x 2
template <>
struct vec_t<nv_bfloat16, 2> {
nv_bfloat162 data;
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val);
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src);
};
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) {
data = make_bfloat162(val, val);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16 *ptr) {
data = *((nv_bfloat162 *)ptr);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16 *ptr) const {
*((nv_bfloat162 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
*((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src);
}
// nv_bfloat16 x 4
template <>
struct vec_t<nv_bfloat16, 4> {
uint2 data;
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)(&data))[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val);
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src);
};
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) {
*(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16 *ptr) {
data = *((uint2 *)ptr);
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16 *ptr) const {
*((uint2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
*((uint2 *)dst) = *((uint2 *)src);
}
// nv_bfloat16 x 8 or more
template <size_t vec_size>
struct vec_t<nv_bfloat16, vec_size> {
uint4 data[vec_size / 8];
FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) {
return ((nv_bfloat16 *)data)[i];
}
FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const {
return ((const nv_bfloat16 *)data)[i];
}
FLASHINFER_INLINE void fill(nv_bfloat16 val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
*(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val);
*(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val);
}
}
FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
data[i] = ((uint4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
const nv_bfloat16 *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 8; ++i) {
((uint4 *)dst)[i] = ((uint4 *)src)[i];
}
}
};
/******************* vec_t<float> *******************/
// float x 1
template <>
struct vec_t<float, 1> {
float data;
FLASHINFER_INLINE float &operator[](size_t i) {
return ((float *)(&data))[i];
}
FLASHINFER_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
}
FLASHINFER_INLINE void fill(float val);
FLASHINFER_INLINE void load(const float *ptr);
FLASHINFER_INLINE void store(float *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
};
FLASHINFER_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
FLASHINFER_INLINE void vec_t<float, 1>::load(const float *ptr) { data = *ptr; }
FLASHINFER_INLINE void vec_t<float, 1>::store(float *ptr) const { *ptr = data; }
FLASHINFER_INLINE void vec_t<float, 1>::memcpy(float *dst, const float *src) {
*dst = *src;
}
// float x 2
template <>
struct vec_t<float, 2> {
float2 data;
FLASHINFER_INLINE float &operator[](size_t i) {
return ((float *)(&data))[i];
}
FLASHINFER_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
}
FLASHINFER_INLINE void fill(float val);
FLASHINFER_INLINE void load(const float *ptr);
FLASHINFER_INLINE void store(float *ptr) const;
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
};
FLASHINFER_INLINE void vec_t<float, 2>::fill(float val) {
data = make_float2(val, val);
}
FLASHINFER_INLINE void vec_t<float, 2>::load(const float *ptr) {
data = *((float2 *)ptr);
}
FLASHINFER_INLINE void vec_t<float, 2>::store(float *ptr) const {
*((float2 *)ptr) = data;
}
FLASHINFER_INLINE void vec_t<float, 2>::memcpy(float *dst, const float *src) {
*((float2 *)dst) = *((float2 *)src);
}
// float x 4 or more
template <size_t vec_size>
struct vec_t<float, vec_size> {
float4 data[vec_size / 4];
FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; }
FLASHINFER_INLINE const float &operator[](size_t i) const {
return ((const float *)(data))[i];
}
FLASHINFER_INLINE void fill(float val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = make_float4(val, val, val, val);
}
}
FLASHINFER_INLINE void load(const float *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = ((float4 *)ptr)[i];
}
}
FLASHINFER_INLINE void store(float *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)ptr)[i] = data[i];
}
}
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(src, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_load(const T *ptr) {
cast_load_impl(ptr, *this);
}
template <typename T>
FLASHINFER_INLINE void cast_store(T *ptr) const {
cast_store_impl(*this, ptr);
}
FLASHINFER_INLINE static void memcpy(float *dst, const float *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)dst)[i] = ((float4 *)src)[i];
}
}
};
/******************* vec_t type cast *******************/
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<half, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = half(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<nv_bfloat16, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((float2 *)(&dst.data))[i] =
__bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<nv_bfloat16, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = nv_bfloat16(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((nv_bfloat162 *)(&dst.data))[i] =
__float22bfloat162_rn(((float2 *)(&src.data))[i]);
}
}
}
#ifdef FLASHINFER_USE_FP8
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else if constexpr (vec_size == 2) {
*(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src,
vec_t<half, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<__nv_fp8_e4m3, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e4m3(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((__nv_fp8x4_e4m3 *)(&dst.data))[i] =
__nv_fp8x4_e4m3(((float4 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size> &src,
vec_t<__nv_fp8_e4m3, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e4m3(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
// NOTE(Zihao): need to double check if we properly handle flo and fhi
((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3(
((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src,
vec_t<float, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else if constexpr (vec_size == 2) {
*(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src,
vec_t<half, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = float(src.data);
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 2; ++i) {
((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size> &src,
vec_t<__nv_fp8_e5m2, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e5m2(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((__nv_fp8x4_e5m2 *)(&dst.data))[i] =
__nv_fp8x4_e5m2(((float4 *)(&src.data))[i]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size> &src,
vec_t<__nv_fp8_e5m2, vec_size> &dst) {
if constexpr (vec_size == 1) {
dst.data = __nv_fp8_e5m2(src.data);
} else if constexpr (vec_size == 2) {
*(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data));
} else {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
// NOTE(Zihao): need to double check if we properly handle flo and fhi
((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2(
((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]);
}
}
}
#endif // FLASHINFER_USE_FP8
#endif // VEC_DTYPES_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <cstdint>
#include "bgmv/bgmv_config.h"
namespace {
//====== utils ======
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
const char *a_name, const char *b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
".size(", i, ")");
}
}
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
return (uint32_t(a) << 16) | uint32_t(b);
}
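// Dispatch note (illustrative): the (h_in, h_out) pair is folded into a single
// 32-bit switch key, e.g. pack_u16(16, 4096) == (16u << 16) | 4096u ==
// 0x00101000. FOR_BGMV_WIDE_NARROW(CASE, ...) below generates one case label
// per supported (narrow, wide) and (wide, narrow) combination.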
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_DIM(d, x) \
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) \
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
//====== bgmv ======
template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices,
uint16_t in_features, uint16_t out_features,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
switch (pack_u16(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u16(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}
return true;
}
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t h_in = x.size(1);
int64_t h_out = y.size(1);
int64_t num_layers = w.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
bool ok = false;
if (h_in < 65536 && h_out < 65536) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
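// Note (descriptive comment, not in the original commit): dispatch_bgmv_low_level
// performs the same dtype dispatch as dispatch_bgmv above, but takes an explicit
// y_offset so that h_out columns can be written into a slice of a wider output
// tensor (full_y_size = y.size(1)) instead of the whole row.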
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out,
int64_t y_offset) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t num_layers = w.size(1);
int64_t full_y_size = y.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
bool ok = false;
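  // The nested dispatch below only covers feature sizes that fit in 16 bits;
  // anything larger leaves ok == false and trips the TORCH_CHECK at the end.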
if (h_in < 65536 && h_out < 65536) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
} // namespace
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
"dispatch_bgmv_low_level");
}
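Both dispatch functions carry the same 3x3x2 nested dtype switch flagged by the TODO above. Below is a minimal sketch, not part of this commit, of one way that switch could be collapsed with tag types and generic lambdas. The helper names dtype_tag, with_io_type, and with_weight_type are hypothetical, and the sketch assumes the headers/typedefs the file already pulls in for at::ScalarType, nv_half, and nv_bfloat16.

// Hypothetical helpers: map a runtime at::ScalarType onto a compile-time tag
// and hand it to a generic callable; return false for unsupported dtypes.
template <typename T>
struct dtype_tag { using type = T; };

template <typename F>
bool with_io_type(at::ScalarType t, F&& f) {
  switch (t) {
    case at::ScalarType::Half:     return f(dtype_tag<nv_half>{});
    case at::ScalarType::BFloat16: return f(dtype_tag<nv_bfloat16>{});
    case at::ScalarType::Float:    return f(dtype_tag<float>{});
    default:                       return false;
  }
}

template <typename F>
bool with_weight_type(at::ScalarType t, F&& f) {
  // Weights are only instantiated for half and bfloat16.
  switch (t) {
    case at::ScalarType::Half:     return f(dtype_tag<nv_half>{});
    case at::ScalarType::BFloat16: return f(dtype_tag<nv_bfloat16>{});
    default:                       return false;
  }
}

// Inside dispatch_bgmv_low_level, the whole nested switch would then reduce to
// three nested generic lambdas (same 18 supported combinations as before):
//
//   ok = with_io_type(x.scalar_type(), [&](auto in_tag) {
//     return with_io_type(y.scalar_type(), [&](auto out_tag) {
//       return with_weight_type(w.scalar_type(), [&](auto w_tag) {
//         using in_T  = typename decltype(in_tag)::type;
//         using out_T = typename decltype(out_tag)::type;
//         using W_T   = typename decltype(w_tag)::type;
//         return launch_bgmv_kernel(static_cast<out_T *>(y.data_ptr()),
//                                   static_cast<in_T *>(x.data_ptr()),
//                                   static_cast<W_T *>(w.data_ptr()),
//                                   indicies.data_ptr<int64_t>(), h_in, h_out,
//                                   y_offset, full_y_size, B, num_layers,
//                                   layer_idx, scale);
//       });
//     });
//   });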
@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_and_mul",
+    &gelu_and_mul,
+    "Activation function used in GeGLU.");
   ops.def(
     "gelu_new",
     &gelu_new,
@@ -48,13 +52,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &rotary_embedding,
     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 
-  // Quantization ops
 #ifndef USE_ROCM
+  // Quantization ops
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
+  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+  ops.def(
+    "moe_align_block_size",
+    &moe_align_block_size,
+    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
 
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
@@ -71,9 +82,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &reshape_and_cache,
     "Reshape the key and value tensors and cache them");
   cache_ops.def(
-    "gather_cached_kv",
-    &gather_cached_kv,
-    "Gather key and value from the cache into contiguous QKV tensors");
+    "convert_fp8_e5m2",
+    &convert_fp8_e5m2,
+    "Convert the key and value cache to fp8_e5m2 data type");
 
   // Cuda utils
   pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
@@ -81,4 +92,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "get_device_attribute",
     &get_device_attribute,
     "Gets the specified device attribute.");
+  cuda_utils.def(
+    "get_max_shared_memory_per_block_device_attribute",
+    &get_max_shared_memory_per_block_device_attribute,
+    "Gets the maximum shared memory per block device attribute.");
+
+#ifndef USE_ROCM
+  // Custom all-reduce kernels
+  pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
+  custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
+  custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
+  custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
+  custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
+  custom_ar.def("dispose", &dispose, "dispose");
+  custom_ar.def("meta_size", &meta_size, "meta_size");
+  custom_ar.def("register_buffer", &register_buffer, "register_buffer");
+  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
+                "get_graph_buffer_ipc_meta");
+  custom_ar.def("register_graph_buffers", &register_graph_buffers,
+                "register_graph_buffers");
+#endif
 }