Commit 48bc6eac authored by Tri Dao

[Gen] Add rotary base as an argument to FT attention kernel

parent 7c766b1b
@@ -84,6 +84,7 @@ struct Multihead_attention_params_base {
     // The per-head latent space reserved for rotary embeddings.
     int rotary_embedding_dim = 0;
     bool neox_rotary_style = false;
+    float rotary_base = 0.0f;
     // The maximum length of input sentences.
     int max_input_length = 0;
     // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
...
@@ -1061,10 +1061,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
     if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
         if (handle_kv) {
-            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
         else {
-            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
     }
     else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
@@ -1099,13 +1099,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
                 mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
                 mmha::apply_rotary_embedding(
-                    q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len);
+                    q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
                 mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
             }
             else {
                 mmha::apply_rotary_embedding(
-                    q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength);
+                    q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
             }
             mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
         }
...
@@ -1272,9 +1272,9 @@ inline __device__ void zero(T& dst)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step)
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
 {
-    const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim);
+    const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
     return {cos(inv_freq), sin(inv_freq)};
 }
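For reference, the quantity the kernel calls `inv_freq` here is really the rotation angle at a given timestep: theta = t_step / base^(zid / rot_embed_dim). A minimal NumPy sketch of the same computation (function name illustrative, not part of the kernel):

```python
import numpy as np

def rotary_coefficient(zid: int, rot_embed_dim: int, t_step: float, base: float = 10000.0):
    """Python mirror of the CUDA rotary_embedding_coefficient above:
    returns (cos, sin) of the angle t_step / base**(zid / rot_embed_dim)."""
    angle = t_step / base ** (zid / rot_embed_dim)
    return np.cos(angle), np.sin(angle)

# A larger base gives a smaller angle (lower rotation frequency) for the same
# channel index and timestep -- this is what rotary_base now controls.
print(rotary_coefficient(8, 64, t_step=10.0))                # base 10000
print(rotary_coefficient(8, 64, t_step=10.0, base=24000.0))  # larger base -> smaller angle
```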
@@ -1302,49 +1302,49 @@ inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162
 }
 #endif
-inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     return;
 }
-inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     return;
 }
-inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }
-inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }
-inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
@@ -1352,166 +1352,166 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
     Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
     k_.x = rotary_embedding_transform(k_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
     k_.y = rotary_embedding_transform(k_.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }
-inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }
-inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }
-inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }
 #ifdef ENABLE_BF16
-inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }
 inline __device__ void
-apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step)
+apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }
-inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }
-inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }
...
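Every overload above only threads `base` through to rotary_embedding_coefficient for its vector width (float2, float4, uint4, bf16_8_t, ...); the rotation itself is unchanged. For intuition, here is a scalar sketch of the 2D rotation that rotary_embedding_transform applies to one channel pair (illustrative Python assuming the standard RoPE rotation, not the kernel code itself):

```python
import math

def rotary_transform(x: float, y: float, cos_t: float, sin_t: float):
    # Rotate the (x, y) channel pair by the angle whose cosine/sine the
    # coefficient function produced.
    return x * cos_t - y * sin_t, x * sin_t + y * cos_t

cos_t, sin_t = math.cos(0.5), math.sin(0.5)
print(rotary_transform(1.0, 0.0, cos_t, sin_t))  # -> (cos 0.5, sin 0.5)
```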
@@ -54,6 +54,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const size_t headdim,
                 const int timestep,
                 const int rotary_embedding_dim,
+                const float rotary_base,
                 const bool neox_rotary_style,
                 const int qkv_batch_stride,
                 T *q_ptr,
@@ -82,6 +83,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.num_heads = nheads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
+    params.rotary_base = rotary_base;
     params.neox_rotary_style = neox_rotary_style;
     params.timestep = timestep;
     params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
@@ -107,6 +109,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      c10::optional<const torch::Tensor> length_per_sample_,
                                      const int timestep,
                                      const int rotary_embedding_dim = 0,
+                                     const float rotary_base = 10000.0f,
                                      const bool neox_rotary_style=true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
     int batch_size = v_cache.size(0);
@@ -144,7 +147,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
         set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                   rotary_embedding_dim, neox_rotary_style, q.stride(0),
+                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
                    reinterpret_cast<DataType*>(v.data_ptr()),
@@ -163,5 +166,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("single_query_attention", &single_query_attention, "Attention with a single query",
           py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
           py::arg("length_per_sample_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
-          py::arg("neox_rotary_style")=true);
+          py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
 }
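With the binding above, rotary_base is exposed as an optional argument of ft_attention.single_query_attention, defaulting to 10000.0 so existing callers keep the old behavior. A sketch of a call with the new argument (tensor construction elided; shapes and cache layout follow whatever the extension already expects):

```python
import ft_attention  # the compiled extension defined above

# q, k, v, k_cache, v_cache: CUDA tensors in the layout the kernel expects.
out = ft_attention.single_query_attention(
    q, k, v, k_cache, v_cache,
    lengths_per_sample,        # or None
    timestep,
    rotary_embedding_dim=64,
    rotary_base=24000.0,       # new in this commit
    neox_rotary_style=True,
)
```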
@@ -169,12 +169,13 @@ class RotaryEmbedding(torch.nn.Module):
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """
-    def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None):
+    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, device=None):
         """
             interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                 of 1st half and 2nd half (GPT-NeoX style).
         """
         super().__init__()
+        self.base = float(base)
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
                                                 dtype=torch.float32) / dim))
...
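Storing self.base is what lets the inference path below hand the base to the fused kernel. The Python inv_freq buffer and the kernel's pow(base, zid / rot_embed_dim) agree: index i of inv_freq corresponds to zid = 2*i. A quick check of that correspondence:

```python
import torch

dim, base, t = 64, 24000.0, 10.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
zid = 2                                  # kernel-side channel-pair index
kernel_angle = t / base ** (zid / dim)   # t_step / pow(base, zid / rot_embed_dim)
python_angle = (t * inv_freq[zid // 2]).item()
assert abs(kernel_angle - python_angle) < 1e-5
```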
@@ -75,6 +75,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
     out_proj_bias = getattr(config, 'out_proj_bias', True)
     rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
+    rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
     rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
     rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
     use_flash_attn = getattr(config, 'use_flash_attn', False)
@@ -91,7 +92,8 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
                                qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                                dropout=config.attn_pdrop,
                                softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
-                               rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
+                               rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
+                               rotary_emb_scale_base=rotary_emb_scale_base,
                                rotary_emb_interleaved=rotary_emb_interleaved,
                                use_flash_attn=use_flash_attn,
                                **serial_kwargs, **parallel_kwargs, **factory_kwargs)
...
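Because the base is read with getattr and defaults to 10000.0, any config can opt in without further code changes; for example (field values illustrative, mirroring the test change at the bottom of this commit):

```python
from transformers import GPT2Config

config = GPT2Config.from_pretrained('gpt2')
config.rotary_emb_fraction = 0.5   # rotary_emb_dim = int(0.5 * head_dim)
config.rotary_emb_base = 24000.0   # picked up by create_mixer_cls above
```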
@@ -350,9 +350,9 @@ class MHA(nn.Module):
     def __init__(self, embed_dim, num_heads, cross_attn=False,
                  qkv_proj_bias=True, out_proj_bias=True,
                  dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 fused_bias_fc=False, use_flash_attn=False, return_residual=False,
-                 checkpointing=False, device=None, dtype=None) -> None:
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
+                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
         """
             return_residual: whether to return the input x along with the output. This is for
                 performance reason: for post-norm architecture, returning the input allows us
@@ -377,7 +377,8 @@ class MHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)
         if fused_bias_fc and FusedDense is None:
@@ -511,11 +512,12 @@ class MHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
                 lengths_per_sample, inference_params.sequence_len_offset,
-                self.rotary_emb_dim,
+                self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
@@ -555,8 +557,8 @@ class ParallelMHA(nn.Module):
     def __init__(self, embed_dim, num_heads, process_group, qkv_proj_bias=True, out_proj_bias=True,
                  dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 use_flash_attn=False, checkpointing=False,
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
                  sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
@@ -573,7 +575,8 @@ class ParallelMHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)
         if ColumnParallelLinear is None or RowParallelLinear is None:
@@ -631,11 +634,12 @@ class ParallelMHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
                 lengths_per_sample, inference_params.sequence_len_offset,
-                self.rotary_emb_dim, inference_params.sequence_len_offset,
+                self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
...
@@ -36,7 +36,8 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     config = GPT2Config.from_pretrained(model_name)
     if rotary:
         config.n_positions = 0
-        config.rotary_emb_dim = 64
+        config.rotary_emb_fraction = 0.5
+        config.rotary_emb_base = 24000
         config.residual_in_fp32 = True
     if optimized:
         config.use_flash_attn = True
...
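Note the test now sets rotary_emb_fraction rather than a raw rotary_emb_dim, exercising the int(fraction * head_dim) path in create_mixer_cls, and uses a non-default base of 24000 so the new kernel argument is actually covered. For GPT-2 small (hidden size 768, 12 heads) the resulting rotary dimension is:

```python
head_dim = 768 // 12                  # 64 for GPT-2 small
rotary_emb_dim = int(0.5 * head_dim)  # 32: half of each head is rotated
print(rotary_emb_dim)
```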