Commit 3a9bfd07 authored by Tri Dao's avatar Tri Dao
Browse files

[FT] rotary_cos/sin should have shape (dim) instead of (seqlen, dim)

parent e8a0b4ac
......@@ -1549,7 +1549,7 @@ inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
q = rotary_embedding_transform(q, coef);
}
......@@ -1558,7 +1558,7 @@ inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
......@@ -1570,9 +1570,9 @@ inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_
}
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
q_.x = rotary_embedding_transform(q_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
q_.y = rotary_embedding_transform(q_.y, coef1);
}
......@@ -1584,10 +1584,10 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
q_.x = rotary_embedding_transform(q_.x, coef0);
k_.x = rotary_embedding_transform(k_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
q_.y = rotary_embedding_transform(q_.y, coef1);
k_.y = rotary_embedding_transform(k_.y, coef1);
}
......@@ -1597,7 +1597,7 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embe
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
q = rotary_embedding_transform(q, coef);
}
......@@ -1606,7 +1606,7 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid,
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
......@@ -1616,9 +1616,9 @@ inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_d
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
q.y = rotary_embedding_transform(q.y, coef1);
}
......@@ -1627,10 +1627,10 @@ inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int r
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
......@@ -1640,13 +1640,13 @@ inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_d
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
q.w = rotary_embedding_transform(q.w, coef3);
}
......@@ -1655,16 +1655,16 @@ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int r
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
......@@ -1675,7 +1675,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int ro
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
q = rotary_embedding_transform(q, coef);
}
......@@ -1684,7 +1684,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162&
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
......@@ -1694,9 +1694,9 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embe
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
q.y = rotary_embedding_transform(q.y, coef1);
}
......@@ -1705,10 +1705,10 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid,
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
......@@ -1718,13 +1718,13 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embe
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
q.w = rotary_embedding_transform(q.w, coef3);
}
......@@ -1733,16 +1733,16 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid,
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos + t_step * rot_embed_dim / 2, rotary_sin + t_step * rot_embed_dim / 2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
......
......@@ -160,16 +160,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
if (rotary_cos_.has_value()) {
auto rotary_cos = rotary_cos_.value();
CHECK_DEVICE(rotary_cos);
int rotary_seqlen = rotary_cos.size(0);
rotary_embedding_dim = rotary_cos.size(1) * 2;
CHECK_SHAPE(rotary_cos, rotary_seqlen, rotary_embedding_dim / 2);
rotary_embedding_dim = rotary_cos.size(0) * 2;
CHECK_SHAPE(rotary_cos, rotary_embedding_dim / 2);
CHECK_CONTIGUOUS(rotary_cos);
TORCH_CHECK(rotary_cos.scalar_type() == input_type);
TORCH_CHECK(rotary_sin_.has_value());
auto rotary_sin = rotary_sin_.value();
CHECK_DEVICE(rotary_sin);
CHECK_SHAPE(rotary_cos, rotary_seqlen, rotary_embedding_dim / 2);
CHECK_SHAPE(rotary_cos, rotary_embedding_dim / 2);
CHECK_CONTIGUOUS(rotary_sin);
TORCH_CHECK(rotary_sin.scalar_type() == input_type);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment