Commit ef6b60e2 authored by Casper Hansen's avatar Casper Hansen
Browse files

New kernels

parent 84fb7e98
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
namespace fastertransformer {
#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
f_val.y = __high2float(val);
return f_val;
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union { int8_t int8[2]; int16_t int16; };
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union { int8_t int8[2]; int16_t int16; };
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __floats2bfloat162_rn(val.x, val.y);
#else
return __float22bfloat162_rn(val);
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
val2.y = val;
return val2;
#else
return __bfloat162bfloat162(val);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
return __hadd2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
#else
return __hadd(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
return __hsub2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
#else
return __hsub(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
return __hmul2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
#else
return __hmul(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh, fzl, fzh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
fzl = __low2float(z);
fzh = __high2float(z);
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
return __hfma2(x, y, z);
#endif
}
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
return __hfma(x, y, z);
#endif
}
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh;
fxl = __low2float(x);
fxh = __high2float(x);;
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
return h2exp(x);
#endif
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
__nv_bfloat162 t; t.x = x; t.y = y; return t;
}
#endif
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
fdl = __low2float(d);
fdh = __high2float(d);
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
return a * b * c + d;
#endif
}
#endif // ENABLE_BF16
} // namespace fastertransformer
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
// Adapted from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
#include "decoder_masked_multihead_attention_template.hpp"
////////////////////////////////////////////////////////////////////////////////////////////////////
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
if (smem_sz >= 48 * 1024) { \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
} \
dim3 grid(params.num_heads, params.batch_size); \
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
////////////////////////////////////////////////////////////////////////////////////////////////////
// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
// printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
if (tlength < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
}
else if (tlength < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
}
else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#undef MMHA_LAUNCH_KERNEL
template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
switch (params.hidden_size_per_head) {
case 32:
mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 48:
mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 64:
mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 80:
mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 96:
mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 112:
mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 128:
mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 160:
mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 192:
mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 224:
mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 256:
mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
default:
assert(false);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
////////////////////////////////////////////////////////////////////////////////////////////////////
// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
template<typename T>
struct Multihead_attention_params_base {
// The output buffer. Dimensions B x D.
T* out = nullptr;
// The input Qs and the associated bias. Dimensions B x D and D, resp.
const T *q = nullptr, *q_bias = nullptr;
// The input Ks and the associated bias. Dimensions B x D and D, resp.
const T *k = nullptr, *k_bias = nullptr;
// The input Vs and the associated bias. Dimensions B x D and D, resp.
const T *v = nullptr, *v_bias = nullptr;
// The cache for the Ks. The size must be at least B x L x D.
T* k_cache = nullptr;
// The cache for the Vs. The size must be at least B x L x D.
T* v_cache = nullptr;
// The indirections to use for cache when beam sampling.
const int* cache_indir = nullptr;
// Stride to handle the case when KQV is a single buffer
int stride = 0;
// The batch size.
int batch_size = 0;
// The beam width
int beam_width = 0;
// The sequence length.
int memory_max_len = 0;
// The number of heads (H).
int num_heads = 0;
// The number of heads for KV cache.
int num_kv_heads = 0;
// The hidden dimension per head (Dh).
int hidden_size_per_head = 0;
// The per-head latent space reserved for rotary embeddings.
int rotary_embedding_dim = 0;
bool neox_rotary_style = false;
float rotary_base = 0.0f;
// The maximum length of input sentences.
int max_input_length = 0;
// The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
int timestep = 0;
// The current timestep of each sentences (support different timestep for different sentences)
// The 1.f / sqrt(Dh). Computed on the host.
float inv_sqrt_dh = 0.0f;
// Used when we have some input context like gpt
const int* total_padding_tokens = nullptr;
const bool* masked_tokens = nullptr;
const int* prefix_prompt_lengths = nullptr;
int max_prefix_prompt_length = 0;
const T* relative_attention_bias = nullptr;
int relative_attention_bias_stride = 0;
// The slope per head of linear position bias to attention score (H).
const float* linear_bias_slopes = nullptr;
const T* ia3_key_weights = nullptr;
const T* ia3_value_weights = nullptr;
const int* ia3_tasks = nullptr;
const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr;
int int8_mode = 0;
};
template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
// output cross attentions
float* cross_attention_out = nullptr;
int max_decoder_seq_len = 0;
bool is_return_cross_attentions = false;
// allows to exist attention eary
bool* finished = nullptr;
// required in case of cross attention
// will need it here till if constexpr in c++17
int* memory_length_per_sample = nullptr;
// required in case of masked attention with different length
const int* length_per_sample = nullptr;
};
template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
// output cross attentions
float* cross_attention_out = nullptr;
int max_decoder_seq_len = 0;
bool is_return_cross_attentions = false;
// allows to exist attention eary
bool* finished = nullptr;
// required in case of cross attention
int* memory_length_per_sample = nullptr;
// required in case of masked attention with different length
const int* length_per_sample = nullptr;
};
template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;
template<class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
template<typename T>
struct outputCrossAttentionParam {
// max decoder output length
int max_decoder_seq_len = 0;
T* cross_attention_out = nullptr;
bool is_return_cross_attentions = false;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream);
#endif
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream);
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
This diff is collapsed.
// Adapted from NVIDIA/FasterTransformer and FlashAttention
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>
#include "ft_attention.h"
#include "decoder_masked_multihead_attention.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \
if (TYPE == at::ScalarType::Half) { \
using scalar_t = at::Half; \
__VA_ARGS__(); \
} else if (TYPE == at::ScalarType::BFloat16) { \
using scalar_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (TYPE == at::ScalarType::Float) { \
using scalar_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
}
template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
const cudaStream_t& stream);
template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
const cudaStream_t& stream);
template<typename T>
struct SATypeConverter {
using Type = T;
};
template<>
struct SATypeConverter<at::Half> {
using Type = uint16_t;
};
template<>
struct SATypeConverter<at::BFloat16> {
using Type = __nv_bfloat16;
};
template <typename T>
void set_params(Masked_multihead_attention_params<T> &params,
const size_t batch_size,
const size_t nheads,
const size_t nheads_kv,
const size_t memory_max_seqlen,
const size_t headdim,
const int timestep,
const int rotary_embedding_dim,
const float rotary_base,
const bool neox_rotary_style,
const int qkv_batch_stride,
T *q_ptr,
T *k_ptr,
T *v_ptr,
T *k_cache_ptr,
T *v_cache_ptr,
int *length_per_sample,
float *alibi_slopes_ptr,
T *out_ptr) {
// Reset the parameters
memset(&params, 0, sizeof(params));
params.q = q_ptr;
params.k = k_ptr;
params.v = v_ptr;
params.q_bias = nullptr;
params.k_bias = nullptr;
params.v_bias = nullptr;
params.k_cache = k_cache_ptr;
params.v_cache = v_cache_ptr;
params.linear_bias_slopes = alibi_slopes_ptr;
params.out = out_ptr;
params.cache_indir = nullptr;
params.stride = qkv_batch_stride;
params.batch_size = batch_size;
params.beam_width = 1;
params.memory_max_len = memory_max_seqlen;
params.num_heads = nheads;
params.num_kv_heads = nheads_kv;
params.hidden_size_per_head = headdim;
params.rotary_embedding_dim = rotary_embedding_dim;
params.rotary_base = rotary_base;
params.neox_rotary_style = neox_rotary_style;
params.timestep = timestep;
params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
params.total_padding_tokens = nullptr;
params.masked_tokens = nullptr;
params.prefix_prompt_lengths = nullptr;
params.max_prefix_prompt_length = 0;
params.relative_attention_bias = nullptr;
params.relative_attention_bias_stride = 0;
params.cross_attention_out = nullptr;
params.max_decoder_seq_len = 0;
params.is_return_cross_attentions = false;
params.finished = nullptr;
params.memory_length_per_sample = nullptr;
params.length_per_sample = length_per_sample;
}
torch::Tensor single_query_attention(const torch::Tensor q,
const torch::Tensor k,
const torch::Tensor v,
torch::Tensor k_cache,
torch::Tensor v_cache,
c10::optional<const torch::Tensor> length_per_sample_,
c10::optional<const torch::Tensor> alibi_slopes_,
const int timestep,
const int rotary_embedding_dim,
const float rotary_base,
// neox_rotary_style = not interleaved
const bool neox_rotary_style) {
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
int batch_size = v_cache.size(0);
int nheads = q.size(1);
int nheads_kv = v_cache.size(1);
int memory_max_seqlen = v_cache.size(2);
int headdim = v_cache.size(3);
CHECK_SHAPE(q, batch_size, nheads, headdim);
CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
// k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
// TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
if (length_per_sample_.has_value()) {
auto length_per_sample = length_per_sample_.value();
CHECK_DEVICE(length_per_sample);
CHECK_SHAPE(length_per_sample, batch_size);
CHECK_CONTIGUOUS(length_per_sample);
TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
CHECK_SHAPE(alibi_slopes, nheads);
CHECK_CONTIGUOUS(alibi_slopes);
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32);
}
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
torch::Tensor out = torch::empty_like(q);
DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
using DataType = typename SATypeConverter<scalar_t>::Type;
Masked_multihead_attention_params<DataType> params;
set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim,
timestep, rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
reinterpret_cast<DataType*>(q.data_ptr()),
reinterpret_cast<DataType*>(k.data_ptr()),
reinterpret_cast<DataType*>(v.data_ptr()),
reinterpret_cast<DataType*>(k_cache.data_ptr()),
reinterpret_cast<DataType*>(v_cache.data_ptr()),
length_per_sample_.has_value()
? length_per_sample_.value().data_ptr<int>() : nullptr,
alibi_slopes_.has_value()
? alibi_slopes_.value().data_ptr<float>(): nullptr,
reinterpret_cast<DataType*>(out.data_ptr()));
auto stream = at::cuda::getCurrentCUDAStream();
masked_multihead_attention(params, stream);
});
return out;
}
\ No newline at end of file
#pragma once
#include <torch/extension.h>
torch::Tensor single_query_attention(const torch::Tensor q,
const torch::Tensor k,
const torch::Tensor v,
torch::Tensor k_cache,
torch::Tensor v_cache,
c10::optional<const torch::Tensor> length_per_sample_,
c10::optional<const torch::Tensor> alibi_slopes_,
const int timestep,
const int rotary_embedding_dim = 0,
const float rotary_base = 10000.0f,
const bool neox_rotary_style=true);
\ No newline at end of file
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attention/ft_attention.h"
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
m.def("rotary_embedding", &rotary_embedding, "Apply rotary embedding to query and key");
}
m.def("single_query_attention", &single_query_attention, "Attention with a single query",
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
}
\ No newline at end of file
// Inspired by https://github.com/ankan-ban/llama_cu_awq
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#include <stdio.h>
#include <torch/extension.h>
#include "gemv_cuda.h"
#define VECTORIZE_FACTOR 8
#define Q_VECTORIZE_FACTOR 8
#define PACK_FACTOR 8
#define WARP_SIZE 32
// Reduce sum within the warp using the tree reduction algorithm.
__device__ __forceinline__ float warp_reduce_sum(float sum) {
#pragma unroll
for(int i = 4; i >= 0; i--){
sum += __shfl_down_sync(0xffffffff, sum, 1<<i);
}
/*
// Equivalent to the following tree reduction implementation:
sum += __shfl_down_sync(0xffffffff, sum, 16);
sum += __shfl_down_sync(0xffffffff, sum, 8);
sum += __shfl_down_sync(0xffffffff, sum, 4);
sum += __shfl_down_sync(0xffffffff, sum, 2);
sum += __shfl_down_sync(0xffffffff, sum, 1);
*/
return sum;
}
__device__ __forceinline__ int make_divisible(int c, int divisor){
return (c + divisor - 1) / divisor;
}
/*
Computes GEMV (group_size = 64).
Args:
inputs: vector of shape [batch_size, IC];
weight: matrix of shape [OC, IC / 8];
output: vector of shape [OC];
zeros: matrix of shape [OC, IC / group_size / 8];
scaling_factors: matrix of shape [OC, IC / group_size];
Notes:
One cannot infer group_size from the shape of scaling factors.
the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g64(
const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
const int IC, const int OC){
const int group_size = 64;
float psum = 0;
const int batch_idx = blockIdx.z;
const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
half* outputs = _outputs + batch_idx * OC;
// This is essentially zeros_w.
const int num_groups_packed = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
const int weight_w = IC / PACK_FACTOR;
// TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
const int zeros_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
// consistent with input shape
const int sf_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2 * PACK_FACTOR;
// if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w, sf_w);
// tile size: 4 OC x 1024 IC per iter
for(int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++){
// 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
uint64_t packed_zeros = *reinterpret_cast<const uint64_t*>(zeros + oc_idx * zeros_w + packed_group_idx * 2);
uint32_t packed_weights[4];
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
*((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
// load scaling factors
// g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups.
float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF);
int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
const float4* inputs_ptr = inputs + inputs_ptr_delta;
// multiply 32 weights with 32 inputs
#pragma unroll
for (int ic_0 = 0; ic_0 < 4; ic_0++){
// iterate over different uint32_t packed_weights in this loop
uint32_t current_packed_weight = packed_weights[ic_0];
half packed_inputs[PACK_FACTOR];
// each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
*((float4*)packed_inputs) = *(inputs_ptr + ic_0);
#pragma unroll
for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
// iterate over 8 numbers packed within each uint32_t number
float current_single_weight_fp = (float)(current_packed_weight & 0xF);
float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
//if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
current_packed_weight = current_packed_weight >> 4;
}
}
}
}
psum = warp_reduce_sum(psum);
if (threadIdx.x == 0) {
outputs[oc_idx] = __float2half(psum);
}
}
/*
Computes GEMV (group_size = 128).
Args:
inputs: vector of shape [batch_size, IC];
weight: matrix of shape [OC, IC / 8];
output: vector of shape [OC];
zeros: matrix of shape [OC, IC / group_size / 8];
scaling_factors: matrix of shape [OC, IC / group_size];
Notes:
One cannot infer group_size from the shape of scaling factors.
the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g128(
const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
const int IC, const int OC){
const int group_size = 128;
float psum = 0;
const int batch_idx = blockIdx.z;
const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
half* outputs = _outputs + batch_idx * OC;
const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
const int weight_w = IC / PACK_FACTOR;
// TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
// consistent with input shape
const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
//if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
// tile size: 4 OC x 1024 IC per iter
for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){
// 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx);
uint32_t packed_weights[4];
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
*((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
// load scaling factors
// g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF);
int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
const float4* inputs_ptr = inputs + inputs_ptr_delta;
// multiply 32 weights with 32 inputs
#pragma unroll
for (int ic_0 = 0; ic_0 < 4; ic_0++){
// iterate over different uint32_t packed_weights in this loop
uint32_t current_packed_weight = packed_weights[ic_0];
half packed_inputs[PACK_FACTOR];
// each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
*((float4*)packed_inputs) = *(inputs_ptr + ic_0);
#pragma unroll
for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
// iterate over 8 numbers packed within each uint32_t number
float current_single_weight_fp = (float)(current_packed_weight & 0xF);
float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
//if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
current_packed_weight = current_packed_weight >> 4;
}
}
}
}
psum = warp_reduce_sum(psum);
if (threadIdx.x == 0) {
outputs[oc_idx] = __float2half(psum);
}
}
/*
Computes GEMV (PyTorch interface).
Args:
_in_feats: tensor of shape [B, IC];
_kernel: int tensor of shape [OC, IC // 8];
_zeros: int tensor of shape [OC, IC // G // 8];
_scaling_factors: tensor of shape [OC, IC // G];
blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;
Returns:
out_feats: tensor of shape [B, OC];
*/
torch::Tensor gemv_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int group_size)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
// int kernel_volume = _out_in_map.size(1);
auto in_feats = reinterpret_cast<float4*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr<int>());
auto zeros = reinterpret_cast<uint32_t*>(_zeros.data_ptr<int>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
// auto out_in_map = _out_in_map.data_ptr<int>();
auto options =
torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
// kernel is [OC, IC]
at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
int blockDim_z = num_out_feats;
dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
dim3 num_threads(32, 4);
if (group_size == 64)
{
gemv_kernel_g64<<<num_blocks, num_threads>>>(
// pointers
in_feats, kernel, zeros, scaling_factors, out_feats,
// constants
num_in_channels, num_out_channels
);
}
else if (group_size == 128)
{
gemv_kernel_g128<<<num_blocks, num_threads>>>(
// pointers
in_feats, kernel, zeros, scaling_factors, out_feats,
// constants
num_in_channels, num_out_channels
);
}
return _out_feats;
;}
#pragma once
#include <torch/extension.h>
torch::Tensor gemv_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int group_size);
......@@ -47,11 +47,24 @@ requirements = [
"torchvision"
]
include_dirs = []
def get_include_dirs():
include_dirs = []
conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
if os.path.isdir(conda_cuda_include_dir):
include_dirs.append(conda_cuda_include_dir)
conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
if os.path.isdir(conda_cuda_include_dir):
include_dirs.append(conda_cuda_include_dir)
this_dir = os.path.dirname(os.path.abspath(__file__))
include_dirs.append(this_dir)
return include_dirs
def get_generator_flag():
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
return generator_flag
def check_dependencies():
if CUDA_HOME is None:
......@@ -77,6 +90,8 @@ def get_compute_capabilities():
return capability_flags
check_dependencies()
include_dirs = get_include_dirs()
generator_flags = get_generator_flag()
arch_flags = get_compute_capabilities()
if os.name == "nt":
......@@ -86,8 +101,21 @@ if os.name == "nt":
}
else:
extra_compile_args={
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
"nvcc": ["-O3", "-std=c++17"] + arch_flags
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
"nvcc": [
"-O3",
"-std=c++17",
"-DENABLE_BF16",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
] + arch_flags + generator_flags
}
extensions = [
......@@ -97,7 +125,10 @@ extensions = [
"awq_cuda/pybind.cpp",
"awq_cuda/quantization/gemm_cuda_gen.cu",
"awq_cuda/layernorm/layernorm.cu",
"awq_cuda/position_embedding/pos_encoding_kernels.cu"
"awq_cuda/position_embedding/pos_encoding_kernels.cu",
"awq_cuda/quantization/gemv_cuda.cu",
"awq_cuda/attention/ft_attention.cpp",
"awq_cuda/attention/decoder_masked_multihead_attention.cu"
], extra_compile_args=extra_compile_args
)
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment