Commit 62e98144 authored by Tri Dao

[Rotary] Make sure frequency calculation is in fp32

parent 9818f85f
@@ -34,7 +34,7 @@
     if (smem_sz >= 48 * 1024) {                                                             \
         cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
     }                                                                                       \
-    dim3 grid(params.num_heads, params.batch_size);                                         \
+    dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \
     kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -113,6 +113,12 @@ struct Multihead_attention_params_base {
     const float* qkv_scale_out = nullptr;
     const float* attention_out_scale = nullptr;
     int int8_mode = 0;
+
+    const T *rotary_cos = nullptr;
+    const T *rotary_sin = nullptr;
+
+    const int *nnz_head_idx = nullptr;
+    int nnz_heads = 0;
 };

 template<typename T, bool CROSS_ATTENTION>
...
@@ -941,7 +941,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     // The "beam-aware" batch idx
     const int bbi = bi / params.beam_width;
     // The head.
-    const int hi = blockIdx.x;
+    // const int hi = blockIdx.x;
+    const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
    // Combine the batch and the head indices.
     const int bhi = bi * params.num_heads + hi;
     // Combine the "beam-aware" batch idx and the head indices.
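Taken together, the launch-macro change and this lookup let the caller run the kernel over a sparse subset of heads: `blockIdx.x` enumerates entries of `nnz_head_idx` rather than heads, while `bhi` still addresses the dense `num_heads` layout. A minimal Python sketch of that indexing (hypothetical helper, for illustration only; not the CUDA kernel):

```python
# Sketch of the head-index indirection above. With nnz_head_idx=None every head is
# processed; otherwise one block runs per listed head, and the looked-up index 'hi'
# still addresses the dense num_heads layout in the QKV and cache tensors.
def head_indices(num_heads, nnz_head_idx=None):
    n_blocks = num_heads if nnz_head_idx is None else len(nnz_head_idx)
    for block_x in range(n_blocks):  # grid.x in the launch macro above
        hi = block_x if nnz_head_idx is None else nnz_head_idx[block_x]
        yield hi

print(list(head_indices(8)))          # [0, 1, 2, 3, 4, 5, 6, 7]
print(list(head_indices(8, [1, 5])))  # [1, 5]: only two blocks are launched
```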
@@ -1061,10 +1062,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
     if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
         if (handle_kv) {
+            if (params.rotary_cos == nullptr) {
                 apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
+            } else {
+                apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
+            }
         }
         else {
+            if (params.rotary_cos == nullptr) {
                 apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
+            } else {
+                apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
+            }
         }
     }
     else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
@@ -1098,14 +1107,24 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         if (handle_kv) {
             mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
+            if (params.rotary_cos == nullptr) {
                 mmha::apply_rotary_embedding(
                     q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
+            } else {
+                mmha::apply_rotary_embedding(
+                    q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
+            }
             mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
         }
         else {
+            if (params.rotary_cos == nullptr) {
                 mmha::apply_rotary_embedding(
                     q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
+            } else {
+                mmha::apply_rotary_embedding(
+                    q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_cos, params.rotary_sin);
+            }
         }
         mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
     }
...
@@ -1272,10 +1272,10 @@ inline __device__ void zero(T& dst)
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const int t_step, const float base)
 {
-    const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
-    return {cos(inv_freq), sin(inv_freq)};
+    const float pos_idx_inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
+    return {cos(pos_idx_inv_freq), sin(pos_idx_inv_freq)};
 }

 inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
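The rename makes the intent explicit: the quantity is not an inverse frequency but the position index times the inverse frequency, t / base^(zid/d), computed entirely in fp32 (`t_step` is now an `int` and the division happens in float). A Python rendering of the same formula, as a reference sketch rather than the device code:

```python
import math

# Reference sketch of rotary_embedding_coefficient above. zid is the even channel
# index (0, 2, ..., rot_embed_dim - 2) and t_step the position; the angle is
# t_step * inv_freq with inv_freq = 1 / base ** (zid / rot_embed_dim).
def rotary_embedding_coefficient(zid, rot_embed_dim, t_step, base=10000.0):
    pos_idx_inv_freq = t_step / (base ** (zid / float(rot_embed_dim)))
    return math.cos(pos_idx_inv_freq), math.sin(pos_idx_inv_freq)
```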
@@ -1447,8 +1447,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int ro
     q = rotary_embedding_transform(q, coef);
 }

-inline __device__ void
-apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;

@@ -1517,6 +1516,238 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid,
 }
 #endif // ENABLE_BF16
+
+template <typename T>
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const T* rotary_cos, const T* rotary_sin)
+{
+    // zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
+    // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
+    return {float(rotary_cos[zid / 2]), float(rotary_sin[zid / 2])};
+}
+
+// fp16 is special because we use uint16_t for reading the data, for backward compatibility.
+template <>
+inline __device__ float2 rotary_embedding_coefficient<uint16_t>(const int zid, const int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    // zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
+    // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
+    return {float(reinterpret_cast<const __half*>(rotary_cos)[zid / 2]),
+            float(reinterpret_cast<const __half*>(rotary_sin)[zid / 2])};
+}
+
+inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    return;
+}
+
+inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    return;
+}
+
+inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+    k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q_.x = rotary_embedding_transform(q_.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q_.y = rotary_embedding_transform(q_.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
+    Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q_.x = rotary_embedding_transform(q_.x, coef0);
+    k_.x = rotary_embedding_transform(k_.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q_.y = rotary_embedding_transform(q_.y, coef1);
+    k_.y = rotary_embedding_transform(k_.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+    k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    k.x = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    k.y = rotary_embedding_transform(k.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.z = rotary_embedding_transform(q.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.w = rotary_embedding_transform(q.w, coef3);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    k.x = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    k.y = rotary_embedding_transform(k.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.z = rotary_embedding_transform(q.z, coef2);
+    k.z = rotary_embedding_transform(k.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.w = rotary_embedding_transform(q.w, coef3);
+    k.w = rotary_embedding_transform(k.w, coef3);
+}
+
+#ifdef ENABLE_BF16
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q = rotary_embedding_transform(q, coef);
+    k = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    k.x = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    k.y = rotary_embedding_transform(k.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.z = rotary_embedding_transform(q.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.w = rotary_embedding_transform(q.w, coef3);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.x = rotary_embedding_transform(q.x, coef0);
+    k.x = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.y = rotary_embedding_transform(q.y, coef1);
+    k.y = rotary_embedding_transform(k.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.z = rotary_embedding_transform(q.z, coef2);
+    k.z = rotary_embedding_transform(k.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
+    q.w = rotary_embedding_transform(q.w, coef3);
+    k.w = rotary_embedding_transform(k.w, coef3);
+}
+#endif // ENABLE_BF16

 template<typename Vec_T, typename T>
 __device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
...
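All of these new overloads implement the same lookup: instead of recomputing cos/sin from `base`, they read row `t_step` of precomputed `(seqlen, rot_dim / 2)` tables (note the row offset `t_step * rot_embed_dim / 2` applied at each call site) and rotate each even/odd channel pair. A Python sketch of that semantics for a single head vector, under the interleaved pairing these call sites use (names and shapes are illustrative):

```python
import torch

# Illustrative equivalent of the table-lookup overloads above, for one head vector.
# rotary_cos / rotary_sin: (seqlen, rot_dim // 2) tables; x: (headdim,), rot_dim <= headdim.
def apply_rotary_from_tables(x, rotary_cos, rotary_sin, t_step, rot_dim):
    cos = rotary_cos[t_step, : rot_dim // 2].float()  # one coefficient per rotated pair
    sin = rotary_sin[t_step, : rot_dim // 2].float()
    out = x.float().clone()
    x1 = out[0:rot_dim:2].clone()  # even channels (zid = 0, 2, 4, ...)
    x2 = out[1:rot_dim:2].clone()  # odd channels
    out[0:rot_dim:2] = x1 * cos - x2 * sin  # same rotation as rotary_embedding_transform
    out[1:rot_dim:2] = x1 * sin + x2 * cos
    return out.to(x.dtype)  # channels past rot_dim pass through unchanged
```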
@@ -57,13 +57,17 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const float rotary_base,
                 const bool neox_rotary_style,
                 const int qkv_batch_stride,
+                const int nnz_heads,
                 T *q_ptr,
                 T *k_ptr,
                 T *v_ptr,
                 T *k_cache_ptr,
                 T *v_cache_ptr,
                 int *length_per_sample,
-                T *out_ptr) {
+                T *rotary_cos,
+                T *rotary_sin,
+                T *out_ptr,
+                int *nnz_head_idx) {
     // Reset the parameters
     memset(&params, 0, sizeof(params));
     params.q = q_ptr;
@@ -81,6 +85,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.beam_width = 1;
     params.memory_max_len = memory_max_seqlen;
     params.num_heads = nheads;
+    params.nnz_heads = nnz_heads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
     params.rotary_base = rotary_base;
@@ -99,6 +104,9 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.finished = nullptr;
     params.memory_length_per_sample = nullptr;
     params.length_per_sample = length_per_sample;
+    params.rotary_cos = rotary_cos;
+    params.rotary_sin = rotary_sin;
+    params.nnz_head_idx = nnz_head_idx;
 }

 torch::Tensor single_query_attention(const torch::Tensor q,
@@ -107,8 +115,11 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      torch::Tensor k_cache,
                                      torch::Tensor v_cache,
                                      c10::optional<const torch::Tensor> length_per_sample_,
+                                     c10::optional<const torch::Tensor> rotary_cos_,
+                                     c10::optional<const torch::Tensor> rotary_sin_,
+                                     c10::optional<const torch::Tensor> nnz_head_idx_,
                                      const int timestep,
-                                     const int rotary_embedding_dim = 0,
+                                     int rotary_embedding_dim = 0,
                                      const float rotary_base = 10000.0f,
                                      const bool neox_rotary_style=true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
@@ -116,6 +127,9 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     int nheads = v_cache.size(1);
     int memory_max_seqlen = v_cache.size(2);
     int headdim = v_cache.size(3);
+    auto input_type = q.scalar_type();
+    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half
+                || input_type == at::ScalarType::BFloat16);
     CHECK_SHAPE(q, batch_size, nheads, headdim);
     CHECK_SHAPE(k, batch_size, nheads, headdim);
     CHECK_SHAPE(v, batch_size, nheads, headdim);
@@ -129,6 +143,12 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
     CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);

+    TORCH_CHECK(q.scalar_type() == input_type);
+    TORCH_CHECK(k.scalar_type() == input_type);
+    TORCH_CHECK(v.scalar_type() == input_type);
+    TORCH_CHECK(k_cache.scalar_type() == input_type);
+    TORCH_CHECK(v_cache.scalar_type() == input_type);
+
     if (length_per_sample_.has_value()) {
         auto length_per_sample = length_per_sample_.value();
         CHECK_DEVICE(length_per_sample);
@@ -137,6 +157,32 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
     }

+    if (rotary_cos_.has_value()) {
+        auto rotary_cos = rotary_cos_.value();
+        CHECK_DEVICE(rotary_cos);
+        int rotary_seqlen = rotary_cos.size(0);
+        rotary_embedding_dim = rotary_cos.size(1) * 2;
+        CHECK_SHAPE(rotary_cos, rotary_seqlen, rotary_embedding_dim / 2);
+        CHECK_CONTIGUOUS(rotary_cos);
+        TORCH_CHECK(rotary_cos.scalar_type() == input_type);
+
+        TORCH_CHECK(rotary_sin_.has_value());
+        auto rotary_sin = rotary_sin_.value();
+        CHECK_DEVICE(rotary_sin);
+        CHECK_SHAPE(rotary_sin, rotary_seqlen, rotary_embedding_dim / 2);
+        CHECK_CONTIGUOUS(rotary_sin);
+        TORCH_CHECK(rotary_sin.scalar_type() == input_type);
+    }
+
+    if (nnz_head_idx_.has_value()) {
+        auto nnz_head_idx = nnz_head_idx_.value();
+        CHECK_DEVICE(nnz_head_idx);
+        int nnz_heads = nnz_head_idx.size(0);
+        CHECK_SHAPE(nnz_head_idx, nnz_heads);
+        CHECK_CONTIGUOUS(nnz_head_idx);
+        TORCH_CHECK(nnz_head_idx.dtype() == torch::kInt32);
+    }
+
     // Otherwise the kernel will be launched from cuda:0 device
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
@@ -148,6 +194,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         Masked_multihead_attention_params<DataType> params;
         set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
                    rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
+                   nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0,
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
                    reinterpret_cast<DataType*>(v.data_ptr()),
@@ -155,7 +202,13 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                    reinterpret_cast<DataType*>(v_cache.data_ptr()),
                    length_per_sample_.has_value()
                        ? length_per_sample_.value().data_ptr<int>() : nullptr,
-                   reinterpret_cast<DataType*>(out.data_ptr()));
+                   rotary_cos_.has_value()
+                       ? reinterpret_cast<DataType*>(rotary_cos_.value().data_ptr()) : nullptr,
+                   rotary_sin_.has_value()
+                       ? reinterpret_cast<DataType*>(rotary_sin_.value().data_ptr()) : nullptr,
+                   reinterpret_cast<DataType*>(out.data_ptr()),
+                   nnz_head_idx_.has_value() ? nnz_head_idx_.value().data_ptr<int>() : nullptr
+                   );
         auto stream = at::cuda::getCurrentCUDAStream();
         masked_multihead_attention(params, stream);
     });
@@ -165,6 +218,8 @@ torch::Tensor single_query_attention(const torch::Tensor q,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("single_query_attention", &single_query_attention, "Attention with a single query",
           py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
-          py::arg("length_per_sample_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
+          py::arg("length_per_sample_"), py::arg("rotary_cos_"),
+          py::arg("rotary_sin_"), py::arg("nnz_head_idx_"),
+          py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
           py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
 }
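After this change, callers pass the rotary tables and head index positionally (or `None`). A hypothetical usage sketch of the new signature, assuming the extension is importable as `ft_attention` and glossing over the exact cache layout, which the elided `CHECK_SHAPE` lines in this file enforce:

```python
import torch
import ft_attention  # the extension defined by the PYBIND11_MODULE above

# Shapes are illustrative; q/k/v hold the single current token per head.
batch, nheads, headdim, max_seqlen, timestep = 2, 8, 64, 256, 16
q, k, v = (torch.randn(batch, nheads, headdim, device='cuda', dtype=torch.float16)
           for _ in range(3))
k_cache = torch.zeros(batch, nheads, max_seqlen, headdim, device='cuda', dtype=torch.float16)
v_cache = torch.zeros(batch, nheads, max_seqlen, headdim, device='cuda', dtype=torch.float16)
out = ft_attention.single_query_attention(
    q, k, v, k_cache, v_cache,
    None,       # length_per_sample_
    None,       # rotary_cos_: (rotary_seqlen, rotary_dim // 2) table, or None to use rotary_base
    None,       # rotary_sin_: same shape and dtype as rotary_cos_
    None,       # nnz_head_idx_: int32 indices of the heads to compute, or None for all heads
    timestep)   # rotary_embedding_dim, rotary_base, neox_rotary_style keep their defaults
```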
@@ -169,16 +169,28 @@ class RotaryEmbedding(torch.nn.Module):
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """

-    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, device=None):
+    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
+                 pos_idx_in_fp32=True, device=None):
         """
         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
             of 1st half and 2nd half (GPT-NeoX style).
+        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
+            otherwise they might be in lower precision.
+            This option was added because previously (before 2023-07-02), when we construct
+            the position indices, we use the dtype of self.inv_freq. In most cases this would
+            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
+            self.inv_freq would be bf16, and the position indices are also in bf16.
+            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
+            embeddings for some positions will coincide.
+            To maintain compatibility with models previously trained in pure bf16,
+            we add this option.
         """
         super().__init__()
+        self.dim = dim
         self.base = float(base)
+        self.pos_idx_in_fp32 = pos_idx_in_fp32
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
-                                                dtype=torch.float32) / dim))
+        inv_freq = self._compute_inv_freq(device)
         self.register_buffer("inv_freq", inv_freq)
         self.interleaved = interleaved
         self.scale_base = scale_base
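The precision issue the new docstring describes is easy to reproduce: bf16 has an 8-bit significand, so integers above 256 are spaced more than one apart and neighboring position indices collapse onto the same value. A quick standalone check:

```python
import torch

# Position indices built in bf16 collide above 256, which is what pos_idx_in_fp32 avoids:
# in [1024, 2048) representable bf16 values are 8 apart, so 8 positions share one index.
t_bf16 = torch.arange(2048, dtype=torch.float32).to(torch.bfloat16)
print(torch.unique(t_bf16.float()).numel())  # far fewer than 2048 distinct values
print(t_bf16[1990:1996])                     # a run of identical indices around position 1995
```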
@@ -192,31 +204,48 @@ class RotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None

-    def _update_cos_sin_cache(self, x, seqlen_offset=0):
-        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
-        """
-        seqlen = x.shape[1] + seqlen_offset
+    def _compute_inv_freq(self, device=None):
+        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
+                                                 dtype=torch.float32) / self.dim))
+
+    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
-        if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
-                or self._cos_cached.dtype != x.dtype):
+        if (seqlen > self._seq_len_cached or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype):
             self._seq_len_cached = seqlen
-            t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
-            # Don't do einsum, it converts fp32 to fp16
+            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
+            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
+            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
+                # will be large. Having it in bf16 will lose a lot of precision and cause the
+                # cos & sin output to change significantly.
+                # We want to recompute self.inv_freq if it was not loaded in fp32
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+            # Don't do einsum, it converts fp32 to fp16 under AMP
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            freqs = torch.outer(t, inv_freq)
             if self.scale is None:
-                self._cos_cached = torch.cos(freqs).to(x.dtype)
-                self._sin_cached = torch.sin(freqs).to(x.dtype)
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
             else:
                 power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                           - seqlen // 2) / self.scale_base)
                 scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                 # We want the multiplication by scale to happen in fp32
-                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
-                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
-                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
-                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
+                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

     def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         """

@@ -224,7 +253,7 @@ class RotaryEmbedding(torch.nn.Module):
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
             token in the batch.
         """
-        self._update_cos_sin_cache(qkv, seqlen_offset)
+        self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
         if self.scale is None:
             return apply_rotary_emb_qkv_(
                 qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
...
@@ -515,8 +515,13 @@ class MHA(nn.Module):
             rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
-                lengths_per_sample, inference_params.sequence_len_offset,
+                k_cache[batch_start:batch_end],
+                v_cache[batch_start:batch_end],
+                lengths_per_sample,
+                None,  # rotary_cos_
+                None,  # rotary_sin_
+                None,  # nnz_head_idx
+                inference_params.sequence_len_offset,
                 self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True

@@ -637,8 +642,13 @@ class ParallelMHA(nn.Module):
             rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
-                lengths_per_sample, inference_params.sequence_len_offset,
+                k_cache[batch_start:batch_end],
+                v_cache[batch_start:batch_end],
+                lengths_per_sample,
+                None,  # rotary_cos_
+                None,  # rotary_sin_
+                None,  # nnz_head_idx
+                inference_params.sequence_len_offset,
                 self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
...
@@ -267,10 +267,11 @@ def test_llama_generation(model_name):
     del model

     hf_error = (logits_hf - logits_ref).abs().max().item()
-    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
     print(f'HF fp16 logits max diff: {hf_error}')
     print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
-    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
     print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
     assert torch.equal(logits_cg, logits)