Commit 37c494a7 authored by Zhekai Zhang

Initial release
#include "reduction_utils.cuh"
#include <array>
#include "utils.cuh"
#include "activation_kernels_impl.cuh"
#include <cuda_fp16.h>
template<typename T>
__global__ void add_kernel(T *a, T *b, T *c, size_t length) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < length) {
c[i] = a[i] + b[i];
}
}
template<typename T, int unroll>
struct alignas(sizeof(T) * unroll) Tvec {
T data[unroll];
};
template<typename T, int unroll>
__global__ void mul_add_kernel(T *x, T *scale, T *bias, size_t length, int mod_scale, int mod_bias) {
int thread = threadIdx.x + blockIdx.x * blockDim.x;
int i = thread * unroll;
int i_scale = i % mod_scale;
int i_bias = i % mod_bias;
if (i >= length) {
return;
}
using Tvec = ::Tvec<T, unroll>;
Tvec rx = *reinterpret_cast<Tvec *>(&x[i]);
Tvec rscale = *reinterpret_cast<Tvec *>(&scale[i_scale]);
Tvec rbias = *reinterpret_cast<Tvec *>(&bias[i_bias]);
#pragma unroll
for (int k = 0; k < unroll; k++) {
T tmp = rx.data[k] * rscale.data[k] + rbias.data[k];
if constexpr (std::is_same_v<T, half>) {
tmp = __hmin(tmp, (half)65504);
tmp = __hmax(tmp, (half)-65504);
}
rx.data[k] = tmp;
}
*reinterpret_cast<Tvec *>(&x[i]) = rx;
// #pragma unroll
// for (int k = 0; k < unroll; k++) {
// // assert(i < length);
// x[i] = x[i] * scale[i_scale] + bias[i_bias];
// i++;
// i_scale++;
// i_bias++;
// // assert(i_scale < mod_scale);
// // assert(i_bias < mod_bias);
// }
}
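// Illustrative helper (not part of the original file): a minimal host-side
// launcher sketch for mul_add_kernel, assuming `length` is a multiple of
// `unroll` and that x/scale/bias are device pointers aligned for Tvec loads.
// The project's real launch code may differ.
template<typename T, int unroll>
inline void launch_mul_add_example(T *x, T *scale, T *bias, size_t length,
                                   int mod_scale, int mod_bias,
                                   cudaStream_t stream = 0) {
    const int threads = 256;
    const int blocks = (int)((length / unroll + threads - 1) / threads);
    mul_add_kernel<T, unroll><<<blocks, threads, 0, stream>>>(
        x, scale, bias, length, mod_scale, mod_bias);
}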
template<typename T, size_t N>
__global__ void split_mod_kernel(T *input, std::array<T *, N> output, size_t length) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i * N < length) {
#pragma unroll
for (int k = 0; k < N; k++) {
output[k][i] = input[i * N + k];
}
}
}
template<typename T>
__global__ void EmbeddingKernel(int32_t *__restrict__ input_id, T *__restrict__ output, T *__restrict__ lookup, int embed_dim) {
int i = blockIdx.x;
int32_t token_id = input_id[i];
T *output_sample_ptr = output + i * embed_dim;
T *target_embed = lookup + token_id * embed_dim;
for (int j = threadIdx.x; j < embed_dim; j += blockDim.x) {
output_sample_ptr[j] = target_embed[j];
}
}
template<typename T>
__global__ void argmax_sample_kernel(T *input, int32_t *output, int hidden_dim) {
float maxValue = -1e20;
int argmax = 0;
for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
float data = (float)input[blockIdx.x * hidden_dim + i];
if (data > maxValue) {
maxValue = data;
argmax = i;
}
}
// blockAllReduceMax appears to be broken when T = half, so reduce in float
float maxValueBlock = vllm::blockAllReduceMax(maxValue);
if (maxValue == maxValueBlock) {
output[blockIdx.x] = argmax;
}
}
template<typename T>
__global__ void splitqkv_kernel(T *qkv, T *q, T *k, T *v, int q_size, int kv_size) {
int qkv_size = q_size + 2 * kv_size;
for (int i = threadIdx.x; i < qkv_size; i += blockDim.x) {
T data = qkv[blockIdx.x * qkv_size + i];
if (i < q_size) {
q[blockIdx.x * q_size + i] = data;
} else if (i < q_size + kv_size) {
k[blockIdx.x * kv_size + i - q_size] = data;
} else {
v[blockIdx.x * kv_size + i - q_size - kv_size] = data;
}
}
}
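// Illustrative launch sketch (not part of the original file): splitqkv_kernel
// uses one block per row of the fused qkv tensor, laid out as
// [rows, q_size + 2 * kv_size]. A hypothetical call site might look like:
//   splitqkv_kernel<half><<<num_rows, 256, 0, stream>>>(qkv, q, k, v, q_size, kv_size);
// The real launcher in this project may differ.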
template <typename T, int unroll>
__global__ void quant_kernel_static(const T * input, int8_t * output, T scale, size_t length) {
int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;
if (i >= length) {
return;
}
using Tvec = ::Tvec<T, unroll>;
using I8vec = ::Tvec<int8_t, unroll>;
Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
I8vec routput;
float fscale = 1.0f / (float)scale;
#pragma unroll
for (int k = 0; k < unroll; k++) {
routput.data[k] = float_to_int8_rn(((float)rinput.data[k]) * fscale);
}
*reinterpret_cast<I8vec *>(&output[i]) = routput;
}
template <typename T, int unroll>
__global__ void quant_kernel_static_fuse_gelu(const T * input, int8_t * output, T scale, size_t length) {
int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;
if (i >= length) {
return;
}
using Tvec = ::Tvec<T, unroll>;
using I8vec = ::Tvec<int8_t, unroll>;
Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
I8vec routput;
float fscale = 1.0f / (float)scale;
#pragma unroll
for (int k = 0; k < unroll; k++) {
routput.data[k] = float_to_int8_rn(((float)vllm::gelu_new_kernel(rinput.data[k])) * fscale);
}
*reinterpret_cast<I8vec *>(&output[i]) = routput;
}
#include <cstdio>
// input: [..., N]
// output: [..., K] indices of the top-K elements, in reverse order
template<typename T, int K>
__global__
void topk_kernel(const T *input, int *output, int N, int strideInput, int numRows) {
const int row = blockIdx.x * blockDim.x + threadIdx.x;
const int offset = row * strideInput;
if (row >= numRows) {
return;
}
T val[K];
int16_t idx[K];
#pragma unroll
for (int i = 0; i < K; i++) {
val[i] = input[offset + i];
idx[i] = i;
}
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// for (int i = 0; i < K; i++) {
// printf("%d ", idx[i]);
// }
// printf("\n");
// }
for (int i = K; i < N; i++) {
T newval = input[offset + i];
T minval = val[0];
int minpos = 0;
#pragma unroll
for (int j = 1; j < K; j++) {
if (val[j] < minval) {
minval = val[j];
minpos = j;
}
}
if (newval >= minval) {
// shift the slots at or above the evicted minimum down by one; stop at K - 1
// so the last slot can take newval without reading past the end of the arrays
#pragma unroll
for (int j = 0; j < K - 1; j++) {
if (j >= minpos) {
val[j] = val[j + 1];
idx[j] = idx[j + 1];
}
}
val[K - 1] = newval;
idx[K - 1] = i;
}
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// for (int i = 0; i < K; i++) {
// printf("%d ", idx[i]);
// }
// printf("\n");
// }
}
for (int i = 0; i < K; i++) {
output[row * K + i] = idx[K - i - 1];
}
}
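// Illustrative helper (not part of the original file): a minimal host-side
// launcher sketch for topk_kernel with one thread per row; K is the
// compile-time top-k size and strideInput is the row stride in elements.
// The project's real launch code may differ.
template<typename T, int K>
inline void launch_topk_example(const T *input, int *output, int N,
                                int strideInput, int numRows,
                                cudaStream_t stream = 0) {
    const int threads = 128;
    const int blocks = (numRows + threads - 1) / threads;
    topk_kernel<T, K><<<blocks, threads, 0, stream>>>(input, output, N, strideInput, numRows);
}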
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#define FINAL_MASK 0xffffffff
namespace vllm {
template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
return val;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T) (0.0f);
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f (rather than blockDim.x >> 5) so this also behaves
// correctly when blockDim.x is not a multiple of 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockAllReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f (rather than blockDim.x >> 5) so this also behaves
// correctly when blockDim.x is not a multiple of 32
val = (lane < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val)
{
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0)
{
#pragma unroll
for (int i = 0; i < NUM; i++)
{
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++)
{
val[i] = is_mask ? shared[i][lane] : (T) (0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T) 0.0f;
}
template<typename T>
__inline__ __device__ T warpReduceMax(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32));
return val;
}
/* Calculate the maximum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // lane index within the warp
int wid = threadIdx.x >> 5; // warp index
val = warpReduceMax(val); // max within each warp
if (lane == 0) // lane 0 records its warp's max, indexed by warp id
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f (rather than blockDim.x >> 5) so this also behaves
// correctly when blockDim.x is not a multiple of 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockAllReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // lane index within the warp
int wid = threadIdx.x >> 5; // warp index
val = warpReduceMax(val); // max within each warp
if (lane == 0) // lane 0 records its warp's max, indexed by warp id
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f (rather than blockDim.x >> 5) so this also behaves
// correctly when blockDim.x is not a multiple of 32
val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
} // namespace vllm
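// Illustrative usage (not part of the original file): a minimal kernel sketch
// showing how vllm::blockAllReduceSum is typically used, here computing the
// mean of each row of a [numRows, hiddenDim] matrix with one block per row.
// The kernel name and layout are hypothetical.
template<typename T>
__global__ void row_mean_example_kernel(const T *input, float *output, int hiddenDim) {
    float sum = 0.0f;
    for (int i = threadIdx.x; i < hiddenDim; i += blockDim.x) {
        sum += (float)input[blockIdx.x * hiddenDim + i];
    }
    // after blockAllReduceSum every thread in the block holds the full row sum
    sum = vllm::blockAllReduceSum(sum);
    if (threadIdx.x == 0) {
        output[blockIdx.x] = sum / hiddenDim;
    }
}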
// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include <cuda_fp16.h>
template<typename T> struct num_elems;
template <> struct num_elems<float> { static constexpr int value = 1; };
template <> struct num_elems<float2> { static constexpr int value = 2; };
template <> struct num_elems<float4> { static constexpr int value = 4; };
template <> struct num_elems<half> { static constexpr int value = 1; };
template <> struct num_elems<half2> { static constexpr int value = 2; };
#ifdef ENABLE_BF16
template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; };
template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; };
#endif
#ifdef ENABLE_FP8
template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; };
template <> struct num_elems<__nv_fp8x2_e4m3> { static constexpr int value = 2; };
#endif
template<typename T, int num> struct packed_as;
template<typename T> struct packed_as<T, 1> { using type = T; };
template<> struct packed_as<half, 2> { using type = half2; };
template<> struct packed_as<float, 2> { using type = float2; };
template<> struct packed_as<int8_t, 2> { using type = int16_t; };
template<> struct packed_as<int32_t, 2> { using type = int2; };
template<> struct packed_as<half2, 1> { using type = half; };
template<> struct packed_as<float2, 1> { using type = float; };
#ifdef ENABLE_BF16
template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; };
template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; };
#endif
#ifdef ENABLE_FP8
template<> struct packed_as<__nv_fp8_e4m3, 2> { using type = __nv_fp8x2_e4m3; };
template<> struct packed_as<__nv_fp8x2_e4m3, 1> { using type = __nv_fp8_e4m3; };
template<> struct packed_as<__nv_fp8_e5m2, 2> { using type = __nv_fp8x2_e5m2; };
template<> struct packed_as<__nv_fp8x2_e5m2, 1> { using type = __nv_fp8_e5m2; };
#endif
inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); }
inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); }
inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); }
inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); }
inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); }
inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); }
static inline __device__ int8_t float_to_int8_rn(float x)
{
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst);
}
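// Notes on float_to_int8_rn (added for clarity, not in the original file):
// cvt.rni.sat.s8.f32 rounds to the nearest integer (ties to even) and
// saturates to [-128, 127], e.g.
//   float_to_int8_rn(300.5f) == 127
//   float_to_int8_rn(-0.4f)  == 0
//   float_to_int8_rn(-200.f) == -128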
template<typename T>
inline __device__ T ldg(const T* val) {
return __ldg(val);
}
#if ENABLE_BF16
#define bf1622float2 __bfloat1622float2
#define float22bf162 __float22bfloat162_rn
#define bf162bf162 __bfloat162bfloat162
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
#endif
#if ENABLE_BF16
template<>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
template<>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val)
{
return val;
}
template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
{
return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast<float2, float>(float val)
{
return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
{
return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
{
return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float>(float val)
{
return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, half>(half val)
{
return __half2half2(val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
{
union
{
int8_t int8[2];
int16_t int16;
};
union
{
half fp16;
int16_t int16_in;
};
fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
{
union
{
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_half2(int8[0], int8[1]);
}
template <>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_float2(int8[0], int8[1]);
}
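// Illustrative usage (not part of the original file): a minimal device-side
// sketch combining the helpers above to quantize two half values at once.
// The function name and scaling convention are hypothetical.
__device__ inline int16_t quantize_half2_example(half2 v, float inv_scale) {
    float2 f = cuda_cast<float2>(v);   // half2 -> float2
    f = f * inv_scale;                 // scale in fp32
    return cuda_cast<int16_t>(f);      // round + saturate each lane to int8, packed
}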
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
return static_cast<float>(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
{
return __bfloat162float(val);
}
template <>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622float2(val);
}
template <>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
{
return __float2half(__bfloat162float(val));
}
template <>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622int16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
return __float2bfloat16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
return __float2bfloat16(__half2float(val));
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
return bf162bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
return __float2bfloat162_rn(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
return float22bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
__nv_bfloat162 res;
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
res.y = cuda_cast<__nv_bfloat16>(int8[1]);
return res;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
return float22bf162(__half22float2(val));
}
#endif // ENABLE_BF16
template <typename To, typename Ti>
__device__ inline To cuda_sum(Ti val)
{
return cuda_cast<To>(val);
};
template <typename To>
__device__ inline To cuda_sum(float2 val)
{
return cuda_cast<To>(val.x + val.y);
};
// Unary maximum: compute the max of a vector type
template <typename To, typename Ti>
__device__ inline To cuda_max(Ti val)
{
return cuda_cast<To>(val);
};
template <>
__device__ inline float cuda_max(float2 val)
{
return fmaxf(val.x, val.y);
}
template <>
__device__ inline half cuda_max(half2 val)
{
return __hmax(val.x, val.y);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmax(val.x, val.y);
#else
// __hmax on bfloat16 requires sm_80+; fall back to a float comparison so the
// function returns a value on every architecture
return __float2bfloat16(fmaxf(__bfloat162float(val.x), __bfloat162float(val.y)));
#endif
}
#endif
// Binary maximum: compute the max of two scalar types
template <typename T>
__device__ inline T cuda_max(T val1, T val2)
{
return (val1 > val2) ? val1 : val2;
}
template <typename T>
__device__ inline T cuda_abs(T val)
{
assert(false);
return {};
}
template <>
__device__ inline float cuda_abs(float val)
{
return fabs(val);
}
template <>
__device__ inline float2 cuda_abs(float2 val)
{
return make_float2(fabs(val.x), fabs(val.y));
}
template <>
__device__ inline half cuda_abs(half val)
{
return __habs(val);
}
template <>
__device__ inline half2 cuda_abs(half2 val)
{
return __habs2(val);
}
#ifdef ENABLE_BF16
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template <>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
return __habs(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
return __habs2(val);
}
#endif
#endif // ENABLE_BF16
#include "layernorm.h"
#include "kernels/layernorm_kernels.h"
LayerNorm::LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device) :
hidden_size(hidden_size), eps(eps)
{
if (elementwise_affine) {
weight = Tensor::allocate({hidden_size}, dtype, device);
bias = Tensor::allocate({hidden_size}, dtype, device);
}
registerParams
(weight, "weight")
(bias, "bias")
;
}
Tensor LayerNorm::forward(Tensor x) {
Tensor out = Tensor::empty(x.shape, x.scalar_type(), x.device());
layernorm_general(out, x, this->weight, this->bias, this->eps);
return out;
}
Tensor RMSNorm::forward(Tensor x) {
Tensor out = Tensor::empty(x.shape, use_quant ? Tensor::INT8 : x.scalar_type(), x.device());
rms_norm(out, x, this->weight, this->variance_epsilon, this->use_quant);
return out;
}
void RMSNormGeneral::forward_with_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
rms_norm_general_fuse_sum(quantized_hidden_states_buffer, x, this->weight, quantized_sum_buffer, quantized_scale_buffer, variance_epsilon, use_per_token_quant);
}
void RMSNormGeneral::forward_wo_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
rms_norm_general(quantized_hidden_states_buffer, x, this->weight, quantized_scale_buffer, variance_epsilon, use_per_token_quant);
}
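// Illustrative usage (not part of the original file): a minimal sketch of how
// the LayerNorm module above could be used; the shape, dtype and hidden size
// are hypothetical.
//   LayerNorm norm(/*hidden_size=*/4096, /*eps=*/1e-5f,
//                  /*elementwise_affine=*/true, Tensor::FP16, Device::cuda());
//   Tensor y = norm.forward(x);   // x: [num_tokens, 4096] FP16 tensor on the GPU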
#pragma once
#include "common.h"
#include "Tensor.h"
#include "Module.h"
class LayerNorm : public Module {
public:
LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor x);
public:
const int hidden_size;
const float eps;
private:
Tensor weight;
Tensor bias;
};
class RMSNorm : public Module {
public:
RMSNorm(int hidden_size, float eps, bool use_quant, Tensor::ScalarType dtype, Device device) :
use_quant(use_quant), variance_epsilon(eps)
{
weight = Tensor::allocate({hidden_size}, dtype, device);
registerParams(weight, "weight");
}
Tensor forward(Tensor x);
public:
const bool use_quant;
const float variance_epsilon;
Tensor weight;
};
class RMSNormGeneral {
friend class LlamaDecoderLayer;
public:
RMSNormGeneral(int hidden_size, bool act_sum, float eps, bool use_per_token_quant, Device device)
: act_sum(act_sum), use_per_token_quant(use_per_token_quant), variance_epsilon(eps)
{
this->weight = Tensor::ones({hidden_size}, Tensor::FP32, device);
}
void forward(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
if (act_sum) {
forward_with_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer);
} else {
forward_wo_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer);
}
}
private:
void forward_with_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
void forward_wo_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
private:
const bool act_sum;
const bool use_per_token_quant;
const float variance_epsilon;
Tensor weight;
};
#pragma once
#include "common.h"
#include "Tensor.h"
namespace pytorch_compat {
inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
assert (cond);
}
template<typename T>
inline void C10_CUDA_CHECK(T ret) {
return checkCUDA(ret);
}
namespace at {
using ::Tensor;
constexpr auto kFloat32 = Tensor::FP32;
constexpr auto kFloat = Tensor::FP32;
constexpr auto kFloat16 = Tensor::FP16;
constexpr auto kBFloat16 = Tensor::BF16;
constexpr auto kInt32 = Tensor::INT32;
constexpr auto kInt64 = Tensor::INT64;
struct Generator {
Generator() { throw std::runtime_error("Not implemented"); }
std::mutex mutex_;
};
namespace cuda {
using ::getCurrentDeviceProperties;
struct StreamWrapper {
cudaStream_t st;
cudaStream_t stream() const { return st; }
};
inline StreamWrapper getCurrentCUDAStream() {
return StreamWrapper{::getCurrentCUDAStream()};
}
struct CUDAGuard {
int dev;
};
namespace detail {
inline Generator getDefaultCUDAGenerator() {
return Generator();
}
}
}
using CUDAGeneratorImpl = Generator;
template<typename T>
std::unique_ptr<Generator> get_generator_or_default(std::optional<Generator> gen, T gen2) {
throw std::runtime_error("Not implemented");
}
}
namespace torch {
using at::kFloat32;
using at::kFloat;
using at::kFloat16;
using at::kBFloat16;
using at::kInt32;
using at::kInt64;
constexpr Device kCUDA = Device::cuda();
using IntArrayRef = std::vector<int>;
using TensorOptions = Tensor::TensorOptions;
inline Tensor empty_like(const Tensor &tensor) {
return Tensor::empty_like(tensor);
}
inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device());
}
inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
return Tensor::empty(shape, options.dtype(), options.device()).zero_();
}
namespace nn {
namespace functional {
using PadFuncOptions = std::vector<int>;
inline Tensor pad(Tensor x, PadFuncOptions options) {
throw std::runtime_error("Not implemented");
}
}
}
namespace indexing {
constexpr int None = 0;
struct Slice {
int a;
int b;
};
}
}
namespace c10 {
using std::optional;
}
}
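// Illustrative usage (not part of the original file): the shim above lets code
// written against the PyTorch C++ API compile against this project's Tensor
// type, e.g. (condition and variable names hypothetical):
//   using namespace pytorch_compat;
//   Tensor out = torch::empty_like(input);
//   TORCH_CHECK(/* some condition */ true, "error message");
//   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();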
Subproject commit 1fe946289e897ff3ae0289b7260b3c86634dfdc6
Subproject commit a75b4ac483166189a45290783cb0a18af5ff0ea5
Subproject commit 63258397761b3dd96dd171e5a5ad5aa915834c35
Subproject commit 8b6b7d878c89e81614d05edca7936de41ccdd2da
Subproject commit 27cb4c76708608465c413f6d0e6b8d99a4d84302