Commit e90433a0 authored by Casper

Initial commit

parent 5440c0aa
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
......
# AutoAWQ_kernels
\ No newline at end of file
# AutoAWQ Kernels
AutoAWQ Kernels is a new package that has been split out from the [main repository](https://github.com/casper-hansen/AutoAWQ) to avoid long compilation times.
## Requirements
- Windows: Must use WSL2.
- GPU: Must be compute capability 7.5 or higher (see the runtime check sketched below).
- CUDA Toolkit: Must be 11.8 or higher.
\ No newline at end of file
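To confirm a machine meets the GPU and CUDA requirements above, you can query the CUDA runtime directly. The sketch below is illustrative only; the file name and exit-code convention are not part of this package.

```cpp
// check_requirements.cu -- hypothetical helper; compile with `nvcc check_requirements.cu`.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int runtime_version = 0;
    cudaRuntimeGetVersion(&runtime_version);               // e.g. 11080 for CUDA 11.8
    std::printf("CUDA runtime: %d.%d\n", runtime_version / 1000, (runtime_version % 1000) / 10);

    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, 0);                      // query device 0
    std::printf("GPU: %s (compute capability %d.%d)\n", prop.name, prop.major, prop.minor);

    // AutoAWQ Kernels needs CUDA >= 11.8 and compute capability >= 7.5.
    const bool ok = runtime_version >= 11080 && (prop.major * 10 + prop.minor) >= 75;
    std::printf("Requirements %s\n", ok ? "met" : "NOT met");
    return ok ? 0 : 1;
}
```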
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
namespace fastertransformer {
#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
f_val.y = __high2float(val);
return f_val;
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union { int8_t int8[2]; int16_t int16; };
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union { int8_t int8[2]; int16_t int16; };
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __floats2bfloat162_rn(val.x, val.y);
#else
return __float22bfloat162_rn(val);
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
val2.y = val;
return val2;
#else
return __bfloat162bfloat162(val);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
return __hadd2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
#else
return __hadd(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
return __hsub2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
#else
return __hsub(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
return __hmul2(x, y);
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
#else
return __hmul(x, y);
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh, fyl, fyh, fzl, fzh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
fzl = __low2float(z);
fzh = __high2float(z);
return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
return __hfma2(x, y, z);
#endif
}
inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
return __hfma(x, y, z);
#endif
}
inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fxl, fxh;
fxl = __low2float(x);
fxh = __high2float(x);
return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
return h2exp(x);
#endif
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
__nv_bfloat162 t; t.x = x; t.y = y; return t;
}
#endif
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
#endif
}
inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
return a + b + c;
#endif
}
inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
return a * b * c;
#endif
}
inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
fal = __low2float(a);
fah = __high2float(a);
fbl = __low2float(b);
fbh = __high2float(b);
fcl = __low2float(c);
fch = __high2float(c);
fdl = __low2float(d);
fdh = __high2float(d);
return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
return a * b * c + d;
#endif
}
#endif // ENABLE_BF16
} // namespace fastertransformer
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
// Adapted from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
#include "decoder_masked_multihead_attention_template.hpp"
////////////////////////////////////////////////////////////////////////////////////////////////////
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
dim3 grid(params.num_heads, params.batch_size); \
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
////////////////////////////////////////////////////////////////////////////////////////////////////
// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
// printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
if (tlength < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
}
else if (tlength < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
}
else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#undef MMHA_LAUNCH_KERNEL
template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
switch (params.hidden_size_per_head) {
case 32:
mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 48:
mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 64:
mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 80:
mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 96:
mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 112:
mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 128:
mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 160:
mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 192:
mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 224:
mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 256:
mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
default:
assert(false);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
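// Example usage (illustrative): CHECK_CUDA(cudaMemsetAsync(dst, 0, num_bytes, stream));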
////////////////////////////////////////////////////////////////////////////////////////////////////
// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
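// As a purely illustrative example of these dimensions: a model with H = 32 heads and Dh = 128 has
// D = H * Dh = 4096, so `out`, `q`, `k` and `v` below are B x 4096 buffers, the K/V caches hold at
// least B x L x 4096 elements each, and `inv_sqrt_dh` is 1 / sqrt(128).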
template<typename T>
struct Multihead_attention_params_base {
// The output buffer. Dimensions B x D.
T* out = nullptr;
// The input Qs and the associated bias. Dimensions B x D and D, resp.
const T *q = nullptr, *q_bias = nullptr;
// The input Ks and the associated bias. Dimensions B x D and D, resp.
const T *k = nullptr, *k_bias = nullptr;
// The input Vs and the associated bias. Dimensions B x D and D, resp.
const T *v = nullptr, *v_bias = nullptr;
// The cache for the Ks. The size must be at least B x L x D.
T* k_cache = nullptr;
// The cache for the Vs. The size must be at least B x L x D.
T* v_cache = nullptr;
// The indirections to use for cache when beam sampling.
const int* cache_indir = nullptr;
// Stride to handle the case when KQV is a single buffer
int stride = 0;
// The batch size.
int batch_size = 0;
// The beam width
int beam_width = 0;
// The sequence length.
int memory_max_len = 0;
// The number of heads (H).
int num_heads = 0;
// The number of heads for KV cache.
int num_kv_heads = 0;
// The hidden dimension per head (Dh).
int hidden_size_per_head = 0;
// The per-head latent space reserved for rotary embeddings.
int rotary_embedding_dim = 0;
bool neox_rotary_style = false;
float rotary_base = 0.0f;
// The maximum length of input sentences.
int max_input_length = 0;
// The current timestep. TODO(bhsueh) Check whether we only need this param in cross attention.
int timestep = 0;
// The current timestep of each sentence (supports a different timestep per sentence)
// The 1.f / sqrt(Dh). Computed on the host.
float inv_sqrt_dh = 0.0f;
// Used when we have some input context like gpt
const int* total_padding_tokens = nullptr;
const bool* masked_tokens = nullptr;
const int* prefix_prompt_lengths = nullptr;
int max_prefix_prompt_length = 0;
const T* relative_attention_bias = nullptr;
int relative_attention_bias_stride = 0;
// The slope per head of linear position bias to attention score (H).
const float* linear_bias_slopes = nullptr;
const T* ia3_key_weights = nullptr;
const T* ia3_value_weights = nullptr;
const int* ia3_tasks = nullptr;
const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr;
int int8_mode = 0;
};
template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
// output cross attentions
float* cross_attention_out = nullptr;
int max_decoder_seq_len = 0;
bool is_return_cross_attentions = false;
// allows to exit attention early
bool* finished = nullptr;
// required in case of cross attention
// will need it here until we can rely on if constexpr (C++17)
int* memory_length_per_sample = nullptr;
// required in case of masked attention with different length
const int* length_per_sample = nullptr;
};
template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
// output cross attentions
float* cross_attention_out = nullptr;
int max_decoder_seq_len = 0;
bool is_return_cross_attentions = false;
// allows to exit attention early
bool* finished = nullptr;
// required in case of cross attention
int* memory_length_per_sample = nullptr;
// required in case of masked attention with different length
const int* length_per_sample = nullptr;
};
template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;
template<class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
template<typename T>
struct outputCrossAttentionParam {
// max decoder output length
int max_decoder_seq_len = 0;
T* cross_attention_out = nullptr;
bool is_return_cross_attentions = false;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream);
#endif
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream);
#endif
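// ---------------------------------------------------------------------------------------------------
// Minimal, purely illustrative call sketch. The names d_qkv, d_out, d_k_cache, d_v_cache, B, H, Dh, D,
// L, step and stream are hypothetical caller-side variables; they are not defined by this header.
//
//   // Suppose d_qkv is a fused [B, 3, H, Dh] buffer for the current step and D = H * Dh.
//   Masked_multihead_attention_params<float> params;
//   params.q = d_qkv;                 // queries start at offset 0
//   params.k = d_qkv + D;             // keys start one D further
//   params.v = d_qkv + 2 * D;         // values after that
//   params.stride = 3 * D;            // distance (in elements) between consecutive batch entries
//   params.out = d_out;               // [B, D] output buffer
//   params.k_cache = d_k_cache;       // [B, H, Dh/x, L, x] key cache
//   params.v_cache = d_v_cache;       // [B, H, L, Dh] value cache
//   params.batch_size = B;
//   params.beam_width = 1;            // must be >= 1
//   params.num_heads = H;
//   params.num_kv_heads = H;          // no grouped-query sharing in this sketch
//   params.hidden_size_per_head = Dh; // must be one of the sizes handled by the dispatch switch
//   params.memory_max_len = L;
//   params.timestep = step;           // number of timesteps already in the cache
//   params.inv_sqrt_dh = 1.f / sqrtf((float)Dh);
//   masked_multihead_attention(params, stream);
// ---------------------------------------------------------------------------------------------------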
////////////////////////////////////////////////////////////////////////////////////////////////////
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include "cuda_bf16_fallbacks.cuh"
#include <assert.h>
#include <float.h>
#include <type_traits>
// #define MMHA_USE_HMMA_FOR_REDUCTION
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
// Does not seem to affect the accuracy that much
#define MMHA_USE_FP32_ACUM_FOR_FMA
// Seems to slightly improve the accuracy
#define MMHA_USE_FP32_ACUM_FOR_OUT
#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
// Does not seem to improve the accuracy
//#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#endif
namespace mmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
//
// The different kernels assign a threadblock for each B x H pair. The grid has size (H, B, 1). We use
// 64, 128 and 256 threads per block.
//
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
// cache buffer helps with memory accesses and contains keys with bias.
//
// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
// values for x are chosen to create chunks of 16 bytes.
//
// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
// HMMA instruction (Tensor Core). Each Q * K^T value is stored in shared memory in FP32.
//
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
// shared memory.
//
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
// timesteps are computed per loop iteration. As with the keys, the values are read from a cache
// except for the current timestep. The layout of the cache buffer for the values is much simpler
// as it is [B, H, L, Dh].
//
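// As a purely illustrative example of the key-cache layout above, with x = 16 / sizeof(T) elements
// per 16B chunk, the key element (b, h, l, d) sits at the linear offset
//
//   ((b * H + h) * (Dh / x) + d / x) * L * x  +  l * x  +  (d % x)
//
// which is exactly how the kernel forms its k_cache offsets below (co = d / x, ci = d % x).
//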
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Dh>
struct Qk_vec_ {
};
template<>
struct Qk_vec_<float, 32> {
using Type = float;
};
template<>
struct Qk_vec_<float, 64> {
using Type = float2;
};
template<>
struct Qk_vec_<float, 128> {
using Type = float4;
};
template<>
struct Qk_vec_<float, 256> {
using Type = float4;
};
template<>
struct Qk_vec_<uint16_t, 32> {
using Type = uint32_t;
};
template<>
struct Qk_vec_<uint16_t, 64> {
using Type = uint32_t;
};
template<>
struct Qk_vec_<uint16_t, 128> {
using Type = uint2;
};
template<>
struct Qk_vec_<uint16_t, 256> {
using Type = uint4;
};
#ifdef ENABLE_BF16
template<>
struct Qk_vec_<__nv_bfloat16, 32> {
using Type = __nv_bfloat162;
};
template<>
struct Qk_vec_<__nv_bfloat16, 64> {
using Type = __nv_bfloat162;
};
template<>
struct Qk_vec_<__nv_bfloat16, 128> {
using Type = bf16_4_t;
};
template<>
struct Qk_vec_<__nv_bfloat16, 256> {
using Type = bf16_8_t;
};
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct K_vec_ {
};
template<>
struct K_vec_<float, 4> {
using Type = float;
};
template<>
struct K_vec_<float, 2> {
using Type = float2;
};
template<>
struct K_vec_<float, 1> {
using Type = float4;
};
template<>
struct K_vec_<uint16_t, 4> {
using Type = uint32_t;
};
template<>
struct K_vec_<uint16_t, 2> {
using Type = uint2;
};
template<>
struct K_vec_<uint16_t, 1> {
using Type = uint4;
};
#ifdef ENABLE_BF16
template<>
struct K_vec_<__nv_bfloat16, 4> {
using Type = __nv_bfloat162;
};
template<>
struct K_vec_<__nv_bfloat16, 2> {
using Type = bf16_4_t;
};
template<>
struct K_vec_<__nv_bfloat16, 1> {
using Type = bf16_8_t;
};
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int V_VEC_SIZE>
struct V_vec_ {
};
template<>
struct V_vec_<float, 1> {
using Type = float;
};
template<>
struct V_vec_<float, 2> {
using Type = float2;
};
template<>
struct V_vec_<float, 4> {
using Type = float4;
};
template<>
struct V_vec_<uint16_t, 2> {
using Type = uint32_t;
};
template<>
struct V_vec_<uint16_t, 4> {
using Type = uint2;
};
template<>
struct V_vec_<uint16_t, 8> {
using Type = uint4;
};
#ifdef ENABLE_BF16
template<>
struct V_vec_<__nv_bfloat16, 2> {
using Type = __nv_bfloat162;
};
template<>
struct V_vec_<__nv_bfloat16, 4> {
using Type = bf16_4_t;
};
template<>
struct V_vec_<__nv_bfloat16, 8> {
using Type = bf16_8_t;
};
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template<typename T>
struct Qk_vec_acum_fp32_ {
};
template<>
struct Qk_vec_acum_fp32_<float> {
using Type = float;
};
template<>
struct Qk_vec_acum_fp32_<float2> {
using Type = float2;
};
template<>
struct Qk_vec_acum_fp32_<float4> {
using Type = float4;
};
// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
template<>
struct Qk_vec_acum_fp32_<uint32_t> {
using Type = float2;
};
template<>
struct Qk_vec_acum_fp32_<uint2> {
using Type = Float4_;
};
template<>
struct Qk_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat16> {
using Type = float;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template<>
struct Qk_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template<>
struct Qk_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct K_vec_acum_fp32_ {
};
template<>
struct K_vec_acum_fp32_<float> {
using Type = float;
};
template<>
struct K_vec_acum_fp32_<float2> {
using Type = float2;
};
template<>
struct K_vec_acum_fp32_<float4> {
using Type = float4;
};
template<>
struct K_vec_acum_fp32_<uint32_t> {
using Type = float2;
};
template<>
struct K_vec_acum_fp32_<uint2> {
using Type = Float4_;
};
template<>
struct K_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
template<>
struct K_vec_acum_fp32_<__nv_bfloat16> {
using Type = float;
};
template<>
struct K_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template<>
struct K_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template<>
struct K_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template<typename T>
struct V_vec_acum_fp32_ {
};
template<>
struct V_vec_acum_fp32_<float> {
using Type = float;
};
template<>
struct V_vec_acum_fp32_<float2> {
using Type = float2;
};
template<>
struct V_vec_acum_fp32_<float4> {
using Type = float4;
};
template<>
struct V_vec_acum_fp32_<uint32_t> {
using Type = float2;
};
template<>
struct V_vec_acum_fp32_<uint2> {
using Type = Float4_;
};
template<>
struct V_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
#ifdef ENABLE_BF16
template<>
struct V_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template<>
struct V_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template<>
struct V_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
#endif // ENABLE_BF16
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
{
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
#else
using K_vec_acum = K_vec;
#endif
// Compute the parallel products for Q*K^T (treat vector lanes separately).
K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct Qk_dot {
template<typename K_vec, int N>
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
{
return qk_dot_<THREADS_PER_KEY>(q, k);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
{
float4 c;
float zero = 0.f;
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5}, \n"
" {%6}, \n"
" {%7, %7, %7, %7}; \n"
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
#else
using K_vec_acum = uint32_t;
#endif
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t qk_vec_ = float2_to_half2(qk_vec);
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
return 0.f;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Qk_dot<uint16_t, 4> {
template<int N>
static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
return qk_hmma_dot_(q, k);
#else
return qk_dot_<4>(q, k);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < WARPS_PER_BLOCK) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(float& dst, float src)
{
dst = src;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint16_t& dst, float src)
{
dst = float_to_half(src);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint32_t& dst, float2 src)
{
dst = float2_to_half2(src);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ void convert_from_float(__nv_bfloat16& dst, float src)
{
dst = __float2bfloat16(src);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst = __float22bfloat162_rn(src);
#else
dst = __floats2bfloat162_rn(src.x, src.y);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint2& dst, Float4_ src)
{
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint2& dst, float4 src)
{
convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint4& dst, Float8_ src)
{
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
dst.z = float2_to_half2(src.z);
dst.w = float2_to_half2(src.w);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
#else
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(bf16_4_t& dst, float4 src)
{
convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
dst.z = __float22bfloat162_rn(src.z);
dst.w = __float22bfloat162_rn(src.w);
#else
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
dst.z = __floats2bfloat162_rn(src.z.x, src.z.y);
dst.w = __floats2bfloat162_rn(src.w.x, src.w.y);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(float2& dst, float2 src)
{
dst = src;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(float4& dst, float4 src)
{
dst = src;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float convert_to_float(float4 u)
{
return u.x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float convert_to_float(uint4 u)
{
float2 tmp = half2_to_float2(u.x);
return tmp.x;
}
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float cast_to_float(float u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 cast_to_float(float2 u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 cast_to_float(float4 u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ cast_to_float(Float4_ u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ cast_to_float(Float8_ u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 cast_to_float(uint32_t u)
{
return half2_to_float2(u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ cast_to_float(uint2 u)
{
Float4_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
return tmp;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ cast_to_float(uint4 u)
{
Float8_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
tmp.z = half2_to_float2(u.z);
tmp.w = half2_to_float2(u.w);
return tmp;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float float_from_int8(int8_t u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 float_from_int8(int16_t u)
{
union {
int16_t int16;
int8_t int8[2];
};
int16 = u;
return make_float2(int8[0], int8[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 float_from_int8(int32_t u)
{
union {
int32_t int32;
int8_t int8[4];
};
int32 = u;
return make_float4(int8[0], int8[1], int8[2], int8[3]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// clang-format off
inline __device__ Float8_ float_from_int8(int64_t u)
{
union {
int64_t int64;
int16_t int16[4];
};
int64 = u;
return Float8_ {float_from_int8(int16[0]),
float_from_int8(int16[1]),
float_from_int8(int16[2]),
float_from_int8(int16[3])};
}
// clang-format on
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int8_t cast_to_int8(float val)
{
union {
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int32_t cast_to_int8(float4 val)
{
union {
int8_t int8[4];
int32_t int32;
};
int8[0] = cast_to_int8(val.x);
int8[1] = cast_to_int8(val.y);
int8[2] = cast_to_int8(val.z);
int8[3] = cast_to_int8(val.w);
return int32;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int64_t cast_to_int8(Float8_ val)
{
union {
int8_t int8[8];
int64_t int64;
};
int8[0] = cast_to_int8(val.x.x);
int8[1] = cast_to_int8(val.x.y);
int8[2] = cast_to_int8(val.y.x);
int8[3] = cast_to_int8(val.y.y);
int8[4] = cast_to_int8(val.z.x);
int8[5] = cast_to_int8(val.z.y);
int8[6] = cast_to_int8(val.w.x);
int8[7] = cast_to_int8(val.w.y);
return int64;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ __host__ T div_up(T m, T n)
{
return (m + n - 1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, bool DO_CROSS_ATTENTION>
inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
int threads_per_value,
int threads_per_block)
{
// The amount of shared memory needed to store the Q*K^T values in float.
const int max_timesteps = min(params.timestep, params.memory_max_len);
size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
// The extra memory needed if we are not using floats for the final logits.
size_t logits_sz = 0;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if (sizeof(T) != 4) {
// TODO
logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) :
div_up(max_timesteps + 1, 4) * 4 * sizeof(T);
}
#endif
// The total size needed during softmax.
size_t softmax_sz = qk_sz + logits_sz;
// The number of partial rows to reduce in the final reduction.
int rows_per_red = threads_per_block / threads_per_value;
// The amount of storage needed to finalize the outputs.
size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;
size_t transpose_rotary_size = 0;
if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T);
}
// The max.
return max(max(softmax_sz, red_sz), transpose_rotary_size);
}
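// Purely illustrative arithmetic for the sizes above (self attention, T = half, Dh = Dh_MAX = 128,
// timestep = 1024, memory_max_len >= 1024, 128 threads per block, threads_per_value = 128 * 2 / 16 = 16,
// and no neox-style rotary, so transpose_rotary_size = 0):
//   qk_sz      = div_up(1025, 4) * 16          = 4112 bytes
//   logits_sz  = div_up(1025, 4) * 4 * 2       = 2056 bytes   (since sizeof(half) != 4)
//   softmax_sz = 4112 + 2056                   = 6168 bytes
//   red_sz     = (128 / 16) * 128 * 2 / 2      = 1024 bytes
// so smem_size_in_bytes returns max(6168, 1024, 0) = 6168 bytes of dynamic shared memory.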
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ constexpr uint32_t shfl_mask(int threads)
{
return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The type of the inputs. Supported types: float and half.
typename T,
// The hidden dimension per head.
int Dh,
int Dh_MAX,
// The number of threads per key.
int THREADS_PER_KEY,
// The number of threads per value.
int THREADS_PER_VALUE,
// The number of threads in a threadblock.
int THREADS_PER_BLOCK,
bool DO_CROSS_ATTENTION>
__global__ void masked_multihead_attention_kernel(Multihead_attention_params<T, DO_CROSS_ATTENTION> params)
{
// Make sure the hidden dimension per head is a multiple of the number of threads per key.
static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
// The size of a warp.
constexpr int WARP_SIZE = 32;
// The number of warps in a threadblock.
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
extern __shared__ char smem_[];
// The shared memory for the Q*K^T values and partial logits in softmax.
float* qk_smem = reinterpret_cast<float*>(smem_);
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
char* logits_smem_ = smem_;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if (sizeof(T) != 4) {
// TODO - change to tlength
const int max_timesteps = min(params.timestep, params.memory_max_len);
logits_smem_ +=
(DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
}
T* logits_smem = reinterpret_cast<T*>(logits_smem_);
#else
float* logits_smem = reinterpret_cast<float*>(logits_smem_);
#endif
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
T* out_smem = reinterpret_cast<T*>(smem_);
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
// A vector of Q or K elements for the current timestep.
using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
// Use alignment for safely casting the shared buffers as Qk_vec.
// Shared memory to store Q inputs.
__shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
// This is one of the reasons we should have a separate kernel for cross attention
__shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];
// The number of elements per vector.
constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
// Make sure the hidden size per head is a multiple of the vector size.
static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
// We will use block wide reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
// The number of vectors per warp.
constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
// The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
// owns x elements, we have to decompose the linear index into chunks of x values and the posi-
// tion of the thread in that chunk.
// The number of elements in a chunk of 16B (that's the x in the above formula).
constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
// The number of K vectors in 16B.
constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
// The batch/beam idx
const int bi = blockIdx.y;
if (params.finished != nullptr && params.finished[bi] == true) {
return;
}
// The beam idx
const int beami = bi % params.beam_width;
// The "beam-aware" batch idx
const int bbi = bi / params.beam_width;
// The head.
const int num_kv_heads = params.num_kv_heads;
const int kv_rep = (params.num_heads / num_kv_heads);
const int hi = blockIdx.x;
const int hi_kv = hi / kv_rep;
// Combine the batch and the head indices.
const int bhi = bi * params.num_heads + hi;
const int bhi_kv = bi * (params.num_heads / kv_rep) + hi_kv;
// Combine the "beam-aware" batch idx and the head indices.
const int bbhi = bbi * params.beam_width * params.num_heads + hi;
const int bbhi_kv = bbi * params.beam_width * (params.num_heads / kv_rep) + hi_kv;
// The thread in the block.
const int tidx = threadIdx.x;
const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);
// Every kv_rep threads have the same kv_cache values. So only the first one writes back.
const int write_kv_cache = handle_kv && (hi % kv_rep == 0);
// While doing the product Q*K^T for the different keys we track the max.
float qk_max = -FLT_MAX;
float qk = 0.0F;
// int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
const int q_base_offset = bi * params.stride + hi * Dh;
const int k_base_offset = bi * params.stride + hi_kv * Dh;
const int v_base_offset = k_base_offset;
const size_t bi_seq_len_offset = bi * params.memory_max_len;
// int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
(params.length_per_sample == nullptr) ?
params.timestep :
params.length_per_sample[bi] + params.max_prefix_prompt_length;
const int first_step = max(0, tlength + 1 - params.memory_max_len);
const int tlength_circ = tlength % params.memory_max_len;
// The first QK_VECS_PER_WARP threads load Q and K + the bias values for the current timestep.
const bool is_masked = tidx >= QK_VECS_PER_WARP;
// The offset in the Q and K buffer also accounts for the batch.
// int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
int v_offset = k_offset;
// The offset in the bias buffer.
// int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
int v_bias_offset = k_bias_offset;
const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
// Trigger the loads from the Q and K buffers.
Qk_vec q;
zero(q);
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
const auto q_scaling = params.qkv_scale_out[0];
const auto q_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
}
else {
q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
}
}
Qk_vec k;
zero(k);
if (DO_CROSS_ATTENTION) {
// The 16B chunk written by the thread.
int co = tidx / QK_VECS_IN_16B;
// The position of the thread in that 16B chunk.
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
// Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
// params.timestep*QK_ELTS_IN_16B +
tlength * QK_ELTS_IN_16B + ci;
k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
*reinterpret_cast<const Qk_vec*>(&params.k_cache[offset]) :
k;
}
else {
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
const auto k_scaling = params.qkv_scale_out[1];
const auto k_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
}
else {
k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
}
}
}
// Trigger the loads from the Q and K bias buffers.
Qk_vec q_bias;
zero(q_bias);
q_bias = (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) && params.q_bias != nullptr ?
*reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
q_bias;
Qk_vec k_bias;
zero(k_bias);
if (handle_kv) {
k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
*reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
k_bias;
}
// Computes the Q/K values with bias.
q = add(q, q_bias);
if (handle_kv) {
k = add(k, k_bias);
}
if (do_ia3 && !is_masked) {
k = mul<Qk_vec, Qk_vec, Qk_vec>(
k,
*reinterpret_cast<const Qk_vec*>(
&params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]));
}
// Padded len
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
if (handle_kv) {
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
}
else {
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
}
}
else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
T* q_smem = reinterpret_cast<T*>(smem_);
T* k_smem = q_smem + params.rotary_embedding_dim;
const int half_rotary_dim = params.rotary_embedding_dim / 2;
const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim;
const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim;
const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts
assert(half_rotary_dim % QK_VEC_SIZE == 0);
if (do_rotary) {
*reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
if (handle_kv) {
*reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
}
}
__syncthreads();
const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
if (do_rotary) {
mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
if (handle_kv) {
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
mmha::apply_rotary_embedding(
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
}
else {
mmha::apply_rotary_embedding(
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
}
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
}
__syncthreads();
if (do_rotary) {
q = *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx);
if (handle_kv) {
k = *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx);
}
}
__syncthreads();
}
if (!is_masked) {
// Store the Q values to shared memory.
*reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
// Store Dh values of k_bias into smem, since we will need to add them later
// if params.timestep == 0
if (DO_CROSS_ATTENTION && params.timestep == 0) {
*reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
}
// Write the K values to the global memory cache.
//
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
// system. We designed it this way as it allows much better memory loads (and there are many
// more loads) + the stores are really "write and forget" since we won't need the ack before
// the end of the kernel. There's plenty of time for the transactions to complete.
// The 16B chunk written by the thread.
int co = tidx / QK_VECS_IN_16B;
// The position of the thread in that 16B chunk.
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
// Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
// params.timestep*QK_ELTS_IN_16B +
tlength_circ * QK_ELTS_IN_16B + ci;
if (write_kv_cache) {
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
}
}
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
#else
using Qk_vec_acum = Qk_vec;
#endif
qk = dot<Qk_vec_acum, Qk_vec>(q, k);
if (QK_VECS_PER_WARP <= WARP_SIZE) {
#pragma unroll
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
}
}
}
if (QK_VECS_PER_WARP > WARP_SIZE) {
constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
}
// Store that value in shared memory. Keep the Q*K^T value in register for softmax.
if (tidx == 0) {
// Normalize qk.
qk *= params.inv_sqrt_dh;
if (params.relative_attention_bias != nullptr) {
// TODO (Haotian): check whether we should replace hi with hi_kv,
// although params.relative_attention_bias is usually not used.
qk = add(qk,
params.relative_attention_bias[hi * params.relative_attention_bias_stride
* params.relative_attention_bias_stride
+ (tlength - padd_len) * params.relative_attention_bias_stride
+ (tlength - padd_len)]);
}
// Add alibi positional encoding
// qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
// We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
qk_max = qk;
qk_smem[tlength - first_step] = qk;
// qk_smem[params.timestep] = qk;
}
// Make sure the data is in shared memory.
__syncthreads();
// The type of queries and keys for the math in the Q*K^T product.
using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
// The number of elements per vector.
constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
// Make sure the hidden size per head is a multiple of the vector size.
static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
// The number of elements per thread.
constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
// The number of vectors per thread.
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
// The position of the first key loaded by each thread from the cache buffer (for this B * H).
int ko = tidx / THREADS_PER_KEY;
// The position of the thread in the chunk of keys.
int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;
static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);
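// For example (illustrative values only): with Dh_MAX = 128 and THREADS_PER_KEY = 4,
// each thread owns K_ELTS_PER_THREAD = 32 key elements, split into
// K_VECS_PER_THREAD = 32 / K_VEC_SIZE vectors that are reused across the loop on K.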
// Load the Q values from shared memory. The values are reused during the loop on K.
K_vec q_vec[K_VECS_PER_THREAD];
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
}
K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
if (DO_CROSS_ATTENTION && params.timestep == 0) {
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
}
}
// The number of timesteps loaded per iteration.
constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
// The number of keys per warp.
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
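// For example, with THREADS_PER_BLOCK = 128 and THREADS_PER_KEY = 4, every iteration
// of the key loop below covers K_PER_ITER = 32 timesteps and a warp covers K_PER_WARP = 8.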
// The base pointer for the key in the cache buffer.
T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
// Base pointer for the beam's batch, before offsetting with indirection buffer
T* k_cache_batch = &params.k_cache[bbhi_kv * params.memory_max_len * Dh + ki];
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
// The prefix prompt length, if present.
const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
const bool has_beams = params.cache_indir != nullptr;
const int* beam_indices = has_beams ? &params.cache_indir[bi_seq_len_offset] : nullptr;
for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
const int ti_circ = ti % params.memory_max_len;
// The keys loaded from the key cache.
K_vec k[K_VECS_PER_THREAD];
K_vec k_vec_zero;
zero(k_vec_zero);
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.memory_max_len + ti_circ;
// if( ti < params.timestep ) {
const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
if (ti < tlength) {
if (!within_bounds) {
k[ii] = k_vec_zero;
}
else {
if (has_beams) {
const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
}
else {
k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
}
}
// add bias and update k_cache
if (DO_CROSS_ATTENTION && params.timestep == 0) {
k[ii] = add(k[ii], k_bias_vec[ii]);
if (do_ia3) {
k[ii] = mul<K_vec, K_vec, K_vec>(
k[ii],
*reinterpret_cast<const K_vec*>(
&params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
+ ii * THREADS_PER_KEY * K_VEC_SIZE]));
}
if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
*reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
}
}
}
}
// Perform the dot product and normalize qk.
//
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
// Store the product to shared memory. There's one qk value per timestep. Update the max.
// if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
if (params.relative_attention_bias != nullptr) {
qk = add(qk,
params.relative_attention_bias[hi * params.relative_attention_bias_stride
* params.relative_attention_bias_stride
+ tlength * params.relative_attention_bias_stride + ti]);
}
if (params.linear_bias_slopes != nullptr) {
// Apply the linear position bias: (ki - qi) * slope[hi].
// The padding token locates between the input context and the generated tokens.
// We need to remove the number of padding tokens in the distance computation.
// ti : 0 1 2 3 4 5 6 7 8 9(tlength)
// token: i i i i p p p o o o where i=input, p=pad, o=output.
// e.g. ti = 2, dist = (9 - 3) - 2 = 4.
int max_context_length = params.max_prefix_prompt_length + params.max_input_length;
float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength;
qk += mul<float, float, float>(params.linear_bias_slopes[hi], dist);
}
// Add alibi positional encoding
// qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
qk_smem[ti - first_step] = qk;
}
}
// Perform the final reduction to compute the max inside each warp.
//
// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
// group so it's not needed to run the reduction inside the group (again).
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
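// For example, with WARP_SIZE = 32 and THREADS_PER_KEY = 4, the butterfly above uses
// strides 16, 8 and 4, so every key-group leader (lane % THREADS_PER_KEY == 0), which
// already holds its group's max, ends up with the warp-wide max of qk.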
// Decompose the thread index into warp and lane.
const int warp = tidx / WARP_SIZE;
const int lane = tidx % WARP_SIZE;
// The warp leader writes the max to shared memory.
if (lane == 0) {
red_smem[warp] = qk_max;
}
// Make sure the products are in shared memory.
__syncthreads();
// The warps finalize the reduction.
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
// Broadcast to all the threads in the warp.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// Compute the logits and start the sum.
float sum = 0.f;
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
sum += logit;
qk_smem[ti - first_step] = logit;
}
// Compute the sum.
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
// Normalize the logits.
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
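// Together with the loop above, this completes a numerically stable softmax over the
// kept timesteps: logit_i = exp(qk_i - qk_max) / (sum_j exp(qk_j - qk_max) + 1e-6),
// with masked timesteps contributing 0.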
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
const size_t cross_attention_out_offset =
params.is_return_cross_attentions ?
bhi_kv * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len :
0;
for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
float logit = qk_smem[ti - first_step] * inv_sum;
if (params.is_return_cross_attentions) {
params.cross_attention_out[cross_attention_out_offset + ti] = logit;
}
convert_from_float(logits_smem[ti - first_step], logit);
}
// Put the Values part below so we can leverage the __syncthreads
// from the previous step.
// The number of elements per vector.
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
// A vector of V elements for the current timestep.
using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
// The group of timesteps handled by this thread in the V loop.
int vo = tidx / THREADS_PER_VALUE;
// The hidden dimensions computed by this particular thread.
int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
// The base pointer for the value in the cache buffer.
T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
// Base pointer for the beam's batch, before offsetting with indirection buffer
T* v_cache_batch = &params.v_cache[bbhi_kv * params.memory_max_len * Dh + vi];
// The number of values processed per iteration of the loop.
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
// One group of threads computes the product(s) for the current timestep.
V_vec v_bias;
zero(v_bias);
// if( vo == params.timestep % V_PER_ITER ) {
if (Dh == Dh_MAX || vi < Dh) {
if (handle_kv) {
if (vo == tlength % V_PER_ITER) {
// Trigger the loads from the V bias buffer.
if (params.v_bias != nullptr) {
v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
}
if (DO_CROSS_ATTENTION) {
*reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
}
}
}
}
// From the previous step, before the values computation.
// Also make sure the logits are in shared memory.
__syncthreads();
// Values continued
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
#else
using V_vec_acum = V_vec;
#endif
// The partial outputs computed by each thread.
V_vec_acum out;
zero(out);
// Loop over the timesteps to compute the partial outputs.
// for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
if (Dh == Dh_MAX || vi < Dh) {
for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
const int ti_circ = ti % params.memory_max_len;
// Fetch offset based on cache_indir when beam sampling
const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
// Load the values from the cache.
V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
if (DO_CROSS_ATTENTION && params.timestep == 0) {
v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
if (do_ia3) {
v = mul<V_vec, V_vec, V_vec>(
v,
*reinterpret_cast<const V_vec*>(
&params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
}
*reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
}
// Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti - first_step];
out = fma(logit, cast_to_float(v), out);
#else
T logit = logits_smem[ti - first_step];
// Update the partial sums.
out = fma(logit, v, out);
#endif
}
}
// One group of threads computes the product(s) for the current timestep.
// if( vo == params.timestep % V_PER_ITER ) {
if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
V_vec v;
if (DO_CROSS_ATTENTION) {
v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
}
else {
// Trigger the loads from the V buffer.
const auto v_offset = v_base_offset + vi;
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
const auto v_scaling = params.qkv_scale_out[2];
const auto v_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);
convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
}
else {
v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
}
// Trigger the loads from the V bias buffer.
// V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
}
// Compute the V values with bias.
v = add(v, v_bias);
if (write_kv_cache) {
if (do_ia3) {
v = mul<V_vec, V_vec, V_vec>(
v,
*reinterpret_cast<const V_vec*>(
&params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
}
// Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
*reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
}
// Initialize the output value with the current timestep.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
// out = fma(logits_smem[params.timestep], cast_to_float(v), out);
out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
#else
// out = fma(logits_smem[params.timestep], v, out);
out = fma(logits_smem[tlength - first_step], v, out);
#endif
}
// Make sure we can start writing to shared memory.
__syncthreads();
// Run the final reduction amongst the different groups computing different partial outputs.
if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
// The midpoint in the number of active groups.
int midpoint = active_groups / 2;
// The upper part of active threads store to shared memory.
if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
#else
*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
}
__syncthreads();
// The bottom warps update their values.
if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
}
}
// Output the final values.
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
out = mul<V_vec_acum, float>(*params.attention_out_scale, out);
*reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
cast_to_int8(out);
}
else {
convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
}
#else
// TODO: support int8_mode?
*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
#endif
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace mmha
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include "cuda_bf16_fallbacks.cuh"
#include <stdint.h>
using namespace fastertransformer;
namespace mmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Float4_ {
float2 x;
float2 y;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
struct bf16_4_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct bf16_8_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct num_elems;
template<>
struct num_elems<float> {
static constexpr int value = 1;
};
template<>
struct num_elems<float2> {
static constexpr int value = 2;
};
template<>
struct num_elems<float4> {
static constexpr int value = 4;
};
template<>
struct num_elems<Float4_> {
static constexpr int value = 4;
};
template<>
struct num_elems<Float8_> {
static constexpr int value = 8;
};
template<>
struct num_elems<uint32_t> {
static constexpr int value = 2;
};
template<>
struct num_elems<uint2> {
static constexpr int value = 4;
};
template<>
struct num_elems<uint4> {
static constexpr int value = 8;
};
#ifdef ENABLE_BF16
template<>
struct num_elems<__nv_bfloat162> {
static constexpr int value = 2;
};
template<>
struct num_elems<bf16_4_t> {
static constexpr int value = 4;
};
template<>
struct num_elems<bf16_8_t> {
static constexpr int value = 8;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int N>
struct packed_type;
template<typename T>
struct packed_type<T, 1> {
using type = T;
};
template<>
struct packed_type<int8_t, 2> {
using type = int16_t;
};
template<>
struct packed_type<int8_t, 4> {
using type = int32_t;
};
template<>
struct packed_type<int8_t, 8> {
using type = int64_t;
};
template<>
struct packed_type<float, 2> {
using type = float2;
};
template<>
struct packed_type<float, 4> {
using type = float4;
};
template<>
struct packed_type<float, 8> {
using type = Float8_;
};
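// For example, packed_type<int8_t, num_elems<uint2>::value>::type is int32_t: four
// quantized int8 values can be loaded as one 32-bit word and dequantized into the
// matching packed float type, packed_type<float, 4>::type == float4.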
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float add(float a, float b)
{
return a + b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 add(float2 a, float2 b)
{
float2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 add(float4 a, float4 b)
{
float4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
return a + b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b)
{
bf16_4_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b)
{
bf16_8_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint16_t add(uint16_t a, uint16_t b)
{
uint16_t c;
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t add(uint32_t a, uint32_t b)
{
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint2 add(uint2 a, uint2 b)
{
uint2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint4 add(uint4 a, uint4 b)
{
uint4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint16_t float_to_half(float f)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better?
float zero = 0.f;
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#endif
return tmp.u16[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t float2_to_half2(float2 f)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
return tmp.u32;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float half_to_float(uint16_t h)
{
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 half2_to_float2(uint32_t v)
{
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float add(float a, uint16_t b)
{
return a + half_to_float(b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float add(float a, __nv_bfloat16 b)
{
return a + __bfloat162float(b);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 add(uint32_t a, float2 fb)
{
float2 fa = half2_to_float2(a);
return add(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ add(uint2 a, Float4_ fb)
{
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ add(uint4 a, Float8_ fb)
{
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t h0_h0(uint16_t a)
{
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float fma(float a, float b, float c)
{
return a * b + c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(float2 a, float2 b, float2 c)
{
float2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(float a, float2 b, float2 c)
{
float2 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 fma(float4 a, float4 b, float4 c)
{
float4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 fma(float a, float4 b, float4 c)
{
float4 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c)
{
Float4_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c)
{
Float8_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float2 add(__nv_bfloat162 a, float2 fb)
{
float2 fa = bf1622float2(a);
return add(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb)
{
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb)
{
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c)
{
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c)
{
return fma(h0_h0(a), b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c)
{
uint2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c)
{
uint32_t s = h0_h0(a);
uint2 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c)
{
uint4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c)
{
uint32_t s = h0_h0(a);
uint4 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float fma(uint16_t a, uint16_t b, float fc)
{
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb + fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc)
{
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return fma(fa, fb, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc)
{
return fma(h0_h0(a), b, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc)
{
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc)
{
uint32_t s = h0_h0(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc)
{
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc)
{
uint32_t s = h0_h0(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(a, b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(bf162bf162(a), b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c)
{
bf16_4_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c)
{
bf16_8_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc)
{
return __bfloat162float(a) * __bfloat162float(b) + fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc)
{
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return fma(fa, fb, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc)
{
return fma(bf162bf162(a), b, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc)
{
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc)
{
__nv_bfloat162 s = bf162bf162(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc)
{
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc)
{
__nv_bfloat162 s = bf162bf162(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b)
{
return a * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul<float, float>(float a, float b)
{
return a * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(float2 a, float2 b)
{
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(float a, float2 b)
{
float2 c;
c.x = a * b.x;
c.y = a * b.y;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float4 a, float4 b)
{
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float a, float4 b)
{
float4 c;
c.x = a * b.x;
c.y = a * b.y;
c.z = a * b.z;
c.w = a * b.w;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(float a, Float8_ b)
{
Float8_ c;
c.x = make_float2(a * b.x.x, a * b.x.y);
c.y = make_float2(a * b.y.x, a * b.y.y);
c.z = make_float2(a * b.z.x, a * b.z.y);
c.w = make_float2(a * b.w.x, a * b.w.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b)
{
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b)
{
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b)
{
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint2 mul(uint2 a, uint2 b)
{
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint2 mul(uint16_t a, uint2 b)
{
uint32_t s = h0_h0(a);
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint4 mul(uint4 a, uint4 b)
{
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint4 mul(uint16_t a, uint4 b)
{
uint32_t s = h0_h0(a);
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(uint16_t a, uint16_t b)
{
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(uint16_t a, float b)
{
return half_to_float(a) * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(uint32_t a, uint32_t b)
{
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(uint16_t a, uint32_t b)
{
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(uint2 a, uint2 b)
{
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b)
{
uint32_t s = h0_h0(a);
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(uint4 a, uint4 b)
{
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b)
{
uint32_t s = h0_h0(a);
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmul(a, b);
#else
return bf16hmul(a, b);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hmul2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b)
{
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
{
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
float fa = (float)a;
float fb = (float)b;
return fa * fb;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, float b)
{
return __bfloat162float(a) * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b)
{
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b)
{
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b)
{
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b)
{
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float v)
{
return v;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float2 v)
{
return v.x + v.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float4 v)
{
return v.x + v.y + v.z + v.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float sum(__nv_bfloat162 v)
{
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_4_t v)
{
return sum(v.x) + sum(v.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_8_t v)
{
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint16_t v)
{
return half_to_float(v);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint32_t v)
{
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint2 v)
{
uint32_t c = add(v.x, v.y);
return sum(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint4 v)
{
#if 1
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
#else
uint32_t c = add(v.x, v.y);
uint32_t d = add(v.z, v.w);
c = add(c, d);
#endif
return sum(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(Float4_ v)
{
return v.x.x + v.x.y + v.y.x + v.y.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(Float8_ v)
{
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float dot(T a, T b)
{
return sum(mul<T, T, T>(a, b));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename A, typename T>
inline __device__ float dot(T a, T b)
{
return sum(mul<A, T, T>(a, b));
}
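// dot(a, b) multiplies element-wise in the input type and reduces the result to a
// float; the two-parameter form lets the caller select a wider accumulator, e.g.
// dot<Float4_, uint2>(a, b) converts the half inputs to fp32 and multiplies/reduces
// in fp32, the path used when MMHA_USE_FP32_ACUM_FOR_FMA is defined in the kernel above.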
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void zero(uint16_t& dst)
{
dst = uint16_t(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ void zero(T& dst)
{
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
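// zero() clears any vector type whose size is a multiple of 4 bytes (float4, uint4,
// Float8_, bf16_8_t, ...) word by word; the uint16_t overload above handles the
// sub-word case used for scalar half values.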
////////////////////////////////////////////////////////////////////////////////////////////////////
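// Rotary embedding (RoPE) helpers. For the channel pair starting at index zid, the
// rotation angle at timestep t_step is theta = t_step / base^(zid / rot_embed_dim)
// (held in the local variable inv_freq below); rotary_embedding_coefficient returns
// (cos(theta), sin(theta)) and rotary_embedding_transform applies the 2-D rotation
// (x, y) -> (x*cos - y*sin, x*sin + y*cos) to each pair of values.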
inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
{
const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
return {cos(inv_freq), sin(inv_freq)};
}
inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
{
float2 rot_v;
rot_v.x = coef.x * v.x - coef.y * v.y;
rot_v.y = coef.x * v.y + coef.y * v.x;
return rot_v;
}
inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
{
float2 fv = half2_to_float2(v);
float2 rot_fv = rotary_embedding_transform(fv, coef);
return float2_to_half2(rot_fv);
}
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
{
float2 fv = bf1622float2(v);
float2 rot_fv = rotary_embedding_transform(fv, coef);
return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
#endif
inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
return;
}
inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
return;
}
inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (4 * tid >= rot_embed_dim) {
return;
}
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
q_.x = rotary_embedding_transform(q_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
q_.y = rotary_embedding_transform(q_.y, coef1);
}
inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (4 * tid >= rot_embed_dim) {
return;
}
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
q_.x = rotary_embedding_transform(q_.x, coef0);
k_.x = rotary_embedding_transform(k_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
q_.y = rotary_embedding_transform(q_.y, coef1);
k_.y = rotary_embedding_transform(k_.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
q.y = rotary_embedding_transform(q.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
q.w = rotary_embedding_transform(q.w, coef3);
}
inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void
apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
q.y = rotary_embedding_transform(q.y, coef1);
}
inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
q.w = rotary_embedding_transform(q.w, coef3);
}
inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
#endif // ENABLE_BF16
template<typename Vec_T, typename T>
__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
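// Reads a packed vector from shared memory in "transposed" order: element 2*i of the
// vector comes from smem[transpose_idx + i] and element 2*i + 1 from
// smem[smem_pitch + transpose_idx + i], i.e. each value is paired with the value one
// smem_pitch below it, which is the pairing that apply_rotary_embedding rotates together.
// write_smem_transpose (declared further below) performs the inverse scatter.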
template<>
__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch)
{
return;
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = smem[transpose_idx];
tmp.u16[1] = smem[smem_pitch + transpose_idx];
vec = tmp.u32;
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp_1, tmp_2;
tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
union {
uint2 u32x2;
uint16_t u16[4];
} tmp_3;
tmp_3.u16[0] = tmp_1.u16[0];
tmp_3.u16[1] = tmp_2.u16[0];
tmp_3.u16[2] = tmp_1.u16[1];
tmp_3.u16[3] = tmp_2.u16[1];
vec = tmp_3.u32x2;
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint64_t u64;
uint16_t u16[4];
} tmp_1, tmp_2;
tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
union {
uint4 u32x4;
uint16_t u16[8];
} tmp_3;
tmp_3.u16[0] = tmp_1.u16[0];
tmp_3.u16[1] = tmp_2.u16[0];
tmp_3.u16[2] = tmp_1.u16[1];
tmp_3.u16[3] = tmp_2.u16[1];
tmp_3.u16[4] = tmp_1.u16[2];
tmp_3.u16[5] = tmp_2.u16[2];
tmp_3.u16[6] = tmp_1.u16[3];
tmp_3.u16[7] = tmp_2.u16[3];
vec = tmp_3.u32x4;
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
__nv_bfloat16 bf16[2];
} tmp_1, tmp_2;
tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
}
template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
union {
uint64_t u64;
__nv_bfloat16 bf16[4];
} tmp_1, tmp_2;
tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
}
#endif // ENABLE_BF16
template<>
__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
vec.x = smem[transpose_idx];
vec.z = smem[transpose_idx + 1];
vec.y = smem[smem_pitch + transpose_idx];
vec.w = smem[smem_pitch + transpose_idx + 1];
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
half u16[2];
} tmp;
tmp.u16[0] = smem[transpose_idx];
tmp.u16[1] = smem[smem_pitch + transpose_idx];
vec = tmp.u32;
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
vec.x = smem[transpose_idx];
vec.y = smem[smem_pitch + transpose_idx];
}
#endif
template<>
__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
vec.x = smem[transpose_idx];
vec.y = smem[smem_pitch + transpose_idx];
}
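// write_smem_transpose is the inverse of vec_from_smem_transpose: it scatters the
// interleaved elements of a vector back into two consecutive shared-memory rows that
// are smem_pitch apart.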
template<typename Vec_T, typename T>
__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
template<>
__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
{
return;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint64_t u64;
uint16_t u16[4];
} tmp_1, tmp_2;
union {
uint4 u32x4;
uint16_t u16[8];
} tmp_3;
tmp_3.u32x4 = vec;
tmp_1.u16[0] = tmp_3.u16[0];
tmp_2.u16[0] = tmp_3.u16[1];
tmp_1.u16[1] = tmp_3.u16[2];
tmp_2.u16[1] = tmp_3.u16[3];
tmp_1.u16[2] = tmp_3.u16[4];
tmp_2.u16[2] = tmp_3.u16[5];
tmp_1.u16[3] = tmp_3.u16[6];
tmp_2.u16[3] = tmp_3.u16[7];
*reinterpret_cast<uint64_t*>(&smem[transpose_idx]) = tmp_1.u64;
*reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp_1, tmp_2;
union {
uint2 u32x2;
uint16_t u16[4];
} tmp_3;
tmp_3.u32x2 = vec;
tmp_1.u16[0] = tmp_3.u16[0];
tmp_2.u16[0] = tmp_3.u16[1];
tmp_1.u16[1] = tmp_3.u16[2];
tmp_2.u16[1] = tmp_3.u16[3];
*reinterpret_cast<uint32_t*>(&smem[transpose_idx]) = tmp_1.u32;
*reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u32 = vec;
smem[transpose_idx] = tmp.u16[0];
smem[smem_pitch + transpose_idx] = tmp.u16[1];
}
template<>
__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
smem[transpose_idx] = vec.x;
smem[transpose_idx + 1] = vec.z;
smem[smem_pitch + transpose_idx] = vec.y;
smem[smem_pitch + transpose_idx + 1] = vec.w;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
half u16[2];
} tmp;
tmp.u32 = vec;
smem[transpose_idx] = tmp.u16[0];
smem[smem_pitch + transpose_idx] = tmp.u16[1];
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
smem[transpose_idx] = vec.x;
smem[smem_pitch + transpose_idx] = vec.y;
}
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}
#endif
template<>
__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
smem[transpose_idx] = vec.x;
smem[smem_pitch + transpose_idx] = vec.y;
}
} // namespace mmha
// Adapted from NVIDIA/FasterTransformer and FlashAttention
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>
#include "ft_attention.h"
#include "decoder_masked_multihead_attention.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \
if (TYPE == at::ScalarType::Half) { \
using scalar_t = at::Half; \
__VA_ARGS__(); \
} else if (TYPE == at::ScalarType::BFloat16) { \
using scalar_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (TYPE == at::ScalarType::Float) { \
using scalar_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
}
template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
const cudaStream_t& stream);
template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
const cudaStream_t& stream);
template<typename T>
struct SATypeConverter {
using Type = T;
};
template<>
struct SATypeConverter<at::Half> {
using Type = uint16_t;
};
template<>
struct SATypeConverter<at::BFloat16> {
using Type = __nv_bfloat16;
};
template <typename T>
void set_params(Masked_multihead_attention_params<T> &params,
const size_t batch_size,
const size_t nheads,
const size_t nheads_kv,
const size_t memory_max_seqlen,
const size_t headdim,
const int timestep,
const int rotary_embedding_dim,
const float rotary_base,
const bool neox_rotary_style,
const int qkv_batch_stride,
T *q_ptr,
T *k_ptr,
T *v_ptr,
T *k_cache_ptr,
T *v_cache_ptr,
int *length_per_sample,
float *alibi_slopes_ptr,
T *out_ptr) {
// Reset the parameters
memset(&params, 0, sizeof(params));
params.q = q_ptr;
params.k = k_ptr;
params.v = v_ptr;
params.q_bias = nullptr;
params.k_bias = nullptr;
params.v_bias = nullptr;
params.k_cache = k_cache_ptr;
params.v_cache = v_cache_ptr;
params.linear_bias_slopes = alibi_slopes_ptr;
params.out = out_ptr;
params.cache_indir = nullptr;
params.stride = qkv_batch_stride;
params.batch_size = batch_size;
params.beam_width = 1;
params.memory_max_len = memory_max_seqlen;
params.num_heads = nheads;
params.num_kv_heads = nheads_kv;
params.hidden_size_per_head = headdim;
params.rotary_embedding_dim = rotary_embedding_dim;
params.rotary_base = rotary_base;
params.neox_rotary_style = neox_rotary_style;
params.timestep = timestep;
params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
params.total_padding_tokens = nullptr;
params.masked_tokens = nullptr;
params.prefix_prompt_lengths = nullptr;
params.max_prefix_prompt_length = 0;
params.relative_attention_bias = nullptr;
params.relative_attention_bias_stride = 0;
params.cross_attention_out = nullptr;
params.max_decoder_seq_len = 0;
params.is_return_cross_attentions = false;
params.finished = nullptr;
params.memory_length_per_sample = nullptr;
params.length_per_sample = length_per_sample;
}
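// single_query_attention runs one decoding step of masked multi-head attention against the
// K/V cache (a single query token per sequence), with optional per-sample lengths, ALiBi
// slopes (passed as linear_bias_slopes), and rotary embedding applied inside the kernel.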
torch::Tensor single_query_attention(const torch::Tensor q,
const torch::Tensor k,
const torch::Tensor v,
torch::Tensor k_cache,
torch::Tensor v_cache,
c10::optional<const torch::Tensor> length_per_sample_,
c10::optional<const torch::Tensor> alibi_slopes_,
const int timestep,
const int rotary_embedding_dim,
const float rotary_base,
// neox_rotary_style = not interleaved
const bool neox_rotary_style) {
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
int batch_size = v_cache.size(0);
int nheads = q.size(1);
int nheads_kv = v_cache.size(1);
int memory_max_seqlen = v_cache.size(2);
int headdim = v_cache.size(3);
CHECK_SHAPE(q, batch_size, nheads, headdim);
CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
// k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
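// Example: with fp16 and headdim = 128, packsize = 8, so k_cache is expected to be
// [batch_size, nheads_kv, 16, memory_max_seqlen, 8].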
TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
// TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
if (length_per_sample_.has_value()) {
auto length_per_sample = length_per_sample_.value();
CHECK_DEVICE(length_per_sample);
CHECK_SHAPE(length_per_sample, batch_size);
CHECK_CONTIGUOUS(length_per_sample);
TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
CHECK_SHAPE(alibi_slopes, nheads);
CHECK_CONTIGUOUS(alibi_slopes);
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32);
}
// Otherwise the kernel will be launched from the cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
torch::Tensor out = torch::empty_like(q);
DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
using DataType = typename SATypeConverter<scalar_t>::Type;
Masked_multihead_attention_params<DataType> params;
set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim,
timestep, rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
reinterpret_cast<DataType*>(q.data_ptr()),
reinterpret_cast<DataType*>(k.data_ptr()),
reinterpret_cast<DataType*>(v.data_ptr()),
reinterpret_cast<DataType*>(k_cache.data_ptr()),
reinterpret_cast<DataType*>(v_cache.data_ptr()),
length_per_sample_.has_value()
? length_per_sample_.value().data_ptr<int>() : nullptr,
alibi_slopes_.has_value()
? alibi_slopes_.value().data_ptr<float>(): nullptr,
reinterpret_cast<DataType*>(out.data_ptr()));
auto stream = at::cuda::getCurrentCUDAStream();
masked_multihead_attention(params, stream);
});
return out;
}
\ No newline at end of file
#pragma once
#include <torch/extension.h>
torch::Tensor single_query_attention(const torch::Tensor q,
const torch::Tensor k,
const torch::Tensor v,
torch::Tensor k_cache,
torch::Tensor v_cache,
c10::optional<const torch::Tensor> length_per_sample_,
c10::optional<const torch::Tensor> alibi_slopes_,
const int timestep,
const int rotary_embedding_dim = 0,
const float rotary_base = 10000.0f,
const bool neox_rotary_style=true);
\ No newline at end of file
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
*/
#include <torch/extension.h>
#include <cuda_fp16.h>
#include "reduction.cuh"
#include "layernorm.h"
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>
static inline __device__ float to_float(half src)
{
return __half2float(src);
}
static inline __device__ float to_float(float src)
{
return src;
}
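// T5-style layernorm (RMSNorm): each row of length n is scaled by the reciprocal RMS and
// then by gamma, with no mean subtraction and no bias:
//   y[i] = x[i] * rsqrt((1/n) * sum_j x[j]^2 + eps) * gamma[i]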
template<typename T>
__global__ void generalT5LayerNorm(
const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n)
{
// Layernorm module in the T5 style: no bias and no subtraction of the mean.
const int tid = threadIdx.x;
__shared__ float s_variance;
float variance = 0.0f;
float local_var_sum = 0.0f;
for (int i = tid; i < n; i += blockDim.x) {
float diff = to_float(__ldg(&input[blockIdx.x * n + i]));
local_var_sum += diff * diff;
}
variance = blockReduceSum(local_var_sum);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / (float)n + layernorm_eps);
}
__syncthreads();
for (int i = tid; i < n; i += blockDim.x) {
output[blockIdx.x * n + i] =
clamp_inf_for_half<T>((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i])));
}
}
template<typename T>
void invokeGeneralT5LayerNorm(T* out,
const T* input,
const T* gamma,
// const T* beta,
const float layernorm_eps,
const int m,
const int n)
{
dim3 grid(m);
dim3 block(min(n, 1024));
/* For general cases, n is equal to hidden_units, e.g., 512/1024.
Since we have warp shuffle inside the code, block.x % 32 should be 0.
*/
if (n % 32 != 0) {
block.x = 1024;
}
block.x = block.x / (4 / sizeof(T)); // if using half, only need half of block.x
/* Should pay attention to the rsqrt precision. */
generalT5LayerNorm<T><<<grid, block>>>(input, gamma, out, layernorm_eps, m, n); // For gpt-3
}
template void invokeGeneralT5LayerNorm(half* out,
const half* input,
const half* gamma,
// const half* beta,
const float layernorm_eps,
const int m,
const int n);
template void invokeGeneralT5LayerNorm(float* out,
const float* input,
const float* gamma,
// const float* beta,
const float layernorm_eps,
const int m,
const int n);
// input: [b, n, c]
void layernorm_forward_cuda(
torch::Tensor _input,
torch::Tensor _gamma,
torch::Tensor _out,
float eps)
{
int m = _input.size(0) * _input.size(1);
int n = _input.size(2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_input));
auto input = reinterpret_cast<half*>(_input.data_ptr<at::Half>());
auto gamma = reinterpret_cast<half*>(_gamma.data_ptr<at::Half>());
auto out = reinterpret_cast<half*>(_out.data_ptr<at::Half>());
invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n);
}
#include <torch/extension.h>
void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps);
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh
*/
#pragma once
#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <float.h>
#include <type_traits>
#define HALF_FLT_MAX 65504.F
#define FINAL_MASK 0xffffffff
template<typename T>
inline __device__ T add(T a, T b) {
return a + b;
}
template<>
inline __device__ half2 add(half2 a, half2 b) {
return __hadd2(a, b);
}
template<>
inline __device__ half add(half a, half b) {
return __hadd(a, b);
}
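// warpReduceSum: butterfly reduction across a full warp using XOR shuffles; after the five
// mask steps (16, 8, 4, 2, 1) every lane holds the sum over all 32 lanes.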
template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); // __shfl_sync on bf16 returns float when sm < 80
return val;
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Modified from blockDim.x << 5 to blockDim.x / 32 to handle the case where
// blockDim.x is not divisible by 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template<typename T>
__device__ __forceinline__ T clamp_inf_for_half(const float input)
{
return input;
}
template<>
__device__ __forceinline__ half clamp_inf_for_half(const float input)
{
// clamp inf values to enable fp16 training
return input > 0.0f ? __float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000));
}
#pragma once
#include <torch/extension.h>
void rotary_embedding_neox(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache);
\ No newline at end of file
/*
Adapted from the VLLM project:
https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "pos_encoding.h"
template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
const int64_t* __restrict__ positions, // [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int stride,
const int num_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
const int n = num_heads * embed_dim;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int out_x = token_idx * stride + head_idx * head_size + x_index;
const int out_y = token_idx * stride + head_idx * head_size + y_index;
const scalar_t cos = __ldg(cache_ptr + x_index);
const scalar_t sin = __ldg(cache_ptr + y_index);
const scalar_t q_x = query[token_head + x_index];
const scalar_t q_y = query[token_head + y_index];
query[out_x] = q_x * cos - q_y * sin;
query[out_y] = q_y * cos + q_x * sin;
const scalar_t k_x = key[token_head + x_index];
const scalar_t k_y = key[token_head + y_index];
key[out_x] = k_x * cos - k_y * sin;
key[out_y] = k_y * cos + k_x * sin;
}
}
void rotary_embedding_neox(
torch::Tensor& positions, // [b, num_tokens]
torch::Tensor& query, // [b, num_tokens, 1, num_heads, head_size]
torch::Tensor& key, // [b, num_tokens, 1, num_heads, head_size]
int head_size,
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
{
int num_tokens = query.size(0) * query.size(1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-2);
int stride = num_heads * head_size;
// TORCH_CHECK(stride == key.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
query.scalar_type(),
"rotary_embedding_neox",
[&] {
rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
stride,
num_heads,
head_size);
});
}
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel.");
m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
}
\ No newline at end of file
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attention/ft_attention.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("single_query_attention", &single_query_attention, "Attention with a single query",
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
}
\ No newline at end of file
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits beforehand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
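// Worked example of the magic-number trick: a low nibble q (0..15) OR'd into 0x6400 yields
// the fp16 bit pattern for 1024 + q, so subtracting {1024, 1024} recovers q. A nibble sitting
// in bits 4..7 yields 1024 + 16*q instead, hence the fma below: (1024 + 16*q) * (1/16) - 64 = q.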
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
}
#include <torch/extension.h>
torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters);
\ No newline at end of file
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "gemm_cuda.h"
#include "dequantize.cuh"
#include <cuda_fp16.h>
#include <c10/cuda/CUDAGuard.h>
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
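// make_divisible is a ceiling division: the smallest k such that k * divisor >= c.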
__device__ __forceinline__ int make_divisible(int c, int divisor){
return (c + divisor - 1) / divisor;
}
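// gemm_forward_4bit_cuda_m16n128k32: each thread block (2 warps of 32 threads) computes a
// 16 x 128 tile of C. Per K-iteration it stages a 16 x 32 slice of A and a dequantized
// 32 x 128 slice of the 4-bit B in shared memory, then issues ldmatrix + mma.sync tensor
// core instructions. The K dimension is strided across split_k_iters blocks (blockIdx_z),
// and each split-K slice writes its partial result into its own M x OC region of C
// (the partial sums are presumably reduced by the caller).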
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (128 + 8)];
int j_factors1 = ((OC + 128 - 1) / 128);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[32];
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 128;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 2
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ (((int)threadIdx.x) % (128 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ ((int)threadIdx.x) % (128 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (128)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
half* C_ptr = C
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thread: read 32 bits -> convert to 8x FP16 (a uint4) -> subtract the zero point, multiply by the scale -> write back as a uint4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// subtract the zero point and multiply by the scale
// TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
{
unsigned int addr;
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
{
unsigned int addr;
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
);
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
#else
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
#endif
}
}
}
// TODO: Shang: Hoist loop invariants.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
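// gemm_forward_4bit_cuda_m16n64k32: same scheme as the kernel above, but with a 16 x 64
// output tile per thread block and a 32 x (64 + 8) B tile staged in shared memory per
// K-iteration.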
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (64 + 8)];
__shared__ half scaling_factors_shared[64];
__shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[16];
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 64;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 4
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ (((int)threadIdx.x) % (64 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ ((int)threadIdx.x) % (64 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (64)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
// B: 32 x 72 (64+8) float16
// each warp: 32 x 4
// each thread: read 32 bits -> convert to 8x FP16 (a uint4) -> subtract the zero point, multiply by the scale -> write back as a uint4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// subtract the zero point and multiply by the scale
// TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
{
{
unsigned int addr;
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
{
{
unsigned int addr;
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
);
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
#else
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
#endif
}
}
}
// TODO: Shang: Hoist loop invariants.
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
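// gemmv2_forward_4bit_cuda_m128n64k32: 4 warps per block computing a 128 x 64 output tile,
// with the group size G as a template parameter. Here B is stored with IC as the innermost
// packed dimension, and dequantization is done with a scalar loop per 32-bit word
// (8 nibbles) rather than the PTX lop3/sub/fma trick used above.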
template <int G>
__global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[64];
__shared__ half A_shared[128 * (32 + 8)];
__shared__ half B_shared[64 * (32 + 8)];
// __shared__ half scaling_factors_shared[64];
// __shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 128 - 1) / 128 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 128 - 1) / 128 * j_factors1);
half A_shared_warp[32];
half B_shared_warp[16];
for (int i_0_3_init = 0; i_0_3_init < 4; ++i_0_3_init) {
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[((i_0_3_init * 16) + (j_0_4_init * 8)) + i] = 0.0;
}
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride_A = 4 * 32 * 8 / 32;
static constexpr int row_stride = 4 * 32 * 8 / 32;
const int make_divisible_multipler = 128 / G;
const int zeros_w = make_divisible(make_divisible(IC / G, 8), make_divisible_multipler) * make_divisible_multipler;
const int sf_w = zeros_w * 8;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
int ld_A_row = (blockIdx_y / j_factors1 * 128 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 128 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (IC / 8) * 8
+ (((int)threadIdx.x) / (32 / 8)) * (IC / 8)
+ (((int)blockIdx_y) % j_factors1) * 64 * (IC / 8)
+ (((int)threadIdx.x) % (32 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 4) * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* zeros_ptr = zeros
+ ((int)threadIdx.y) * zeros_w * 8
+ (((int)threadIdx.x) / (32 / 8)) * zeros_w
+ (((int)blockIdx_y) % j_factors1) * 64 * zeros_w
// this term is zero
+ (((int)threadIdx.x) % (32 / 8)) / G ;
half* scaling_factors_ptr = scaling_factors
+ ((int)threadIdx.y) * sf_w * 8
+ (((int)threadIdx.x) / (32 / 8)) * sf_w
+ (((int)blockIdx_y) % j_factors1) * (64) * sf_w
// this term is zero
+ (((int)threadIdx.x) % (32 / 8)) * 8 / G;
// Haotian: TBD, check, May 29 11:46 AM PST
half* C_ptr = C
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ (((int)threadIdx.y) / 2) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = make_divisible(IC / 32, split_k_iters); // (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
// TODO (Haotian): load scales and zero points to smem
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: Here we assume M % cta_M = 0.
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0)
{
if (ld_A_row + ax0_ax1_fused_0 * row_stride_A < M)
{
*(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = *(uint4*)(A_ptr + (ax0_ax1_fused_0 * row_stride_A * IC) + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = make_uint4(0, 0, 0, 0);
}
}
int* zeros_ptr_local = zeros_ptr + k_0_0 * 32 / G / 8;
half* scaling_factors_ptr_local = scaling_factors_ptr + k_0_0 * 32 / G;
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * (32 / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
// B: 64 x 40 (32+8) float16
// each warp: 32 x 4
// each thread: read 32 bits (8 packed 4-bit weights) -> dequantize to 8x FP16 -> write back as a uint4
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
int B_loaded_current = *(B_ptr_local + ax0_ax1_fused_0 * row_stride * (IC / 8));
int zeros_loaded = *(zeros_ptr_local + ax0_ax1_fused_0 * row_stride * zeros_w);
zeros_loaded >>= ((k_0_0 * 32 / G) % 8) * 4;
float current_zeros = (float)(zeros_loaded & 0xF);
half scaling_factors_loaded = *(scaling_factors_ptr_local + ax0_ax1_fused_0 * row_stride * sf_w);
half B_loaded_fp16[8];
#pragma unroll
for (int ic_1 = 0; ic_1 < 8; ic_1++){
float current_single_weight_fp = (float)(B_loaded_current & 0xF);
half dequantized_weight = __float2half(__half2float(scaling_factors_loaded) * (current_single_weight_fp - current_zeros));
B_loaded_current = B_loaded_current >> 4;
B_loaded_fp16[ic_1] = dequantized_weight;
}
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (32 + 8)) = *reinterpret_cast<uint4*>(B_loaded_fp16);
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) {
{
unsigned int addr;
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[((((((int)threadIdx.y) & 1) * 2560) + (ax0_0 * 640)) + (k_0_1 * 16))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) {
{
unsigned int addr;
asm volatile(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[((((((int)threadIdx.y) >> 1) * 1280) + (ax0_0_1 * 640)) + (k_0_1 * 16))])) + ((((((int)threadIdx.x) >> 4) * 320) + ((((int)threadIdx.x) & 7) * 40)) + (((((int)threadIdx.x) & 15) >> 3) * 8))))
);
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[3])
: "r"(addr)
);
}
}
for (int i_0_3 = 0; i_0_3 < 4; ++i_0_3) {
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
: "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
: "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]));
}
#else
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
: "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
}
{
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3]));
}
#endif
}
}
}
}
// Haotian: Here (May 29 11:46AM PST)
// TODO: Shang: Hoist loop invariants.
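  // Write the FP32 accumulators in C_warp back to the FP16 output tensor; rows
  // beyond M are masked out by the bounds check on row_offset below.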
for (int ax0_0_2 = 0; ax0_0_2 < 4; ++ax0_0_2) {
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 128 + (threadIdx.y % 2) * 64 + ax0_0_2 * 16 + (local_id % 4) / 2 * 8 + ((int)threadIdx.x) / 4;
if (row_offset < M)
{
*(C_ptr + ax1_0 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax0_0_2 * 16) + (ax1_0 * 8) + local_id]);
}
}
}
}
}
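// Illustrative sketch (not called by the kernels in this file): how one packed
// int32 holding eight 4-bit weights is dequantized to FP16, mirroring the
// per-thread unpacking in the B-loading loop above (lowest nibble first,
// value = scale * (weight - zero_point)). The function name and signature are
// hypothetical and exist only to document that step.
__device__ __forceinline__ void dequantize_s4x8_sketch(
    int packed_w, int packed_z, int zero_shift, half scale, half (&out)[8])
{
  // 4-bit zero point that applies to this group of eight weights.
  float zero = (float)((packed_z >> zero_shift) & 0xF);
  float s = __half2float(scale);
#pragma unroll
  for (int i = 0; i < 8; ++i) {
    float w = (float)(packed_w & 0xF);       // extract the lowest nibble
    out[i] = __float2half(s * (w - zero));   // dequantize: scale * (w - zero)
    packed_w >>= 4;                          // advance to the next nibble
  }
}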
// in_feats: M, IC [float16]
// kernel: OC, IC // 8 [int32] -> cast to OC, IC [uint4b]
// scaling_factors: OC, IC // G [float16]
// zeros: OC, IC // G // 8 [int32] -> cast to OC, IC // G [uint4b]
// assume that batch_size < 16 for now
torch::Tensor gemmv2_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int group_size,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
// for this kernel the weights are packed along IC, so OC = _kernel.size(0)
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(0)}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
// blockIdx_x: i_factors[0] * j_factors[0]
// blockIdx_y: i_factors[1] * j_factors[1]
if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks((num_out_feats + 128 - 1) / 128 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 4);
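    // Example (illustrative numbers only): for M = 256, OC = 4096, split_k_iters = 8
    // this launches ceil(256 / 128) * (4096 / 64) * 8 = 1024 blocks of 32 x 4 threads.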
if (group_size == 128)
{
gemmv2_forward_4bit_cuda_m128n64k32<128><<<num_blocks, threads_per_block>>>(
split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (group_size == 64)
{
gemmv2_forward_4bit_cuda_m128n64k32<64><<<num_blocks, threads_per_block>>>(
split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else
{
throw std::invalid_argument("Group size temporarily not supported.");
}
return _out_feats.sum(0);
}
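// Illustrative usage sketch, not exposed through the extension bindings: builds
// random tensors with the shapes/dtypes gemmv2_forward_cuda expects (see the
// layout comments above) and invokes it. The function name and the concrete
// sizes are hypothetical.
static torch::Tensor gemmv2_usage_sketch()
{
    const int M = 8, IC = 4096, OC = 4096, G = 128;  // example sizes only; OC % 64 == 0
    auto fp16 = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    auto i32  = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
    auto in_feats        = torch::randn({M, IC}, fp16);                    // M x IC
    auto kernel          = torch::randint(0, 1 << 30, {OC, IC / 8}, i32);  // eight 4-bit weights per int32
    auto scaling_factors = torch::randn({OC, IC / G}, fp16);
    auto zeros           = torch::randint(0, 1 << 30, {OC, IC / G / 8}, i32);
    // The split-K partial results are summed over dim 0 before returning.
    return gemmv2_forward_cuda(in_feats, kernel, scaling_factors, zeros, G, /*split_k_iters=*/8);
}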
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
torch::Tensor gemm_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
if (group_size % 32 != 0)
throw std::invalid_argument("Group size should be a multiple of 32");
if (num_out_channels % group_size != 0)
throw std::invalid_argument("OC is not multiple of Group size");
if (num_out_channels % 128 == 0)
{
int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
{
int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
return _out_feats.sum(0);
}
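// Illustrative usage sketch for the v1 entry point, also not exposed through the
// bindings. Note the packing differs from gemmv2: here the kernel is packed
// along OC ([IC, OC // 8]) and scales/zeros are grouped along IC. The function
// name and sizes are hypothetical; group_size is inferred from the scale shape.
static torch::Tensor gemm_usage_sketch()
{
    const int M = 8, IC = 4096, OC = 4096, G = 128;  // OC % 64 == 0, G % 32 == 0
    auto fp16 = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    auto i32  = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
    auto in_feats        = torch::randn({M, IC}, fp16);
    auto kernel          = torch::randint(0, 1 << 30, {IC, OC / 8}, i32);
    auto scaling_factors = torch::randn({IC / G, OC}, fp16);
    auto zeros           = torch::randint(0, 1 << 30, {IC / G, OC / 8}, i32);
    return gemm_forward_cuda(in_feats, kernel, scaling_factors, zeros, /*split_k_iters=*/8);
}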