Commit f266fc72 authored by Tri Dao

[Gen, FT] Use tlength instead of params.timestep for rotary

parent a01d1213
@@ -1082,10 +1082,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
     if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
         if (handle_kv) {
-            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
+            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len);
         }
         else {
-            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
+            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len);
         }
     }
     else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
@@ -1120,13 +1120,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
             mmha::apply_rotary_embedding(
-                q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len);
+                q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len);
             mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
         }
         else {
             mmha::apply_rotary_embedding(
-                q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep);
+                q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength);
         }
         mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
     }
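For context on what the last argument of apply_rotary_embedding controls: rotary embeddings rotate the first rotary_embedding_dim components of q and k by angles proportional to the token's position, so the change presumably makes each sequence use its own per-sequence length tlength (minus its padding) as that position, rather than the batch-global params.timestep. The standalone C++ sketch below only illustrates that role of the position argument; apply_rotary, head_dim, and the interleaved (x, y) pair layout are illustrative assumptions, not the FT kernel's actual implementation.

// Minimal host-side sketch of a rotary position embedding (illustrative only;
// names and layout are assumptions, not the FT kernel's code).
#include <cmath>
#include <cstdio>
#include <vector>

// Rotate the first `rotary_dim` elements of q and k as interleaved (x, y) pairs,
// using angles that depend on the token position `t`. This `t` is the role of the
// last argument that the commit changes from params.timestep to tlength.
void apply_rotary(std::vector<float>& q, std::vector<float>& k, int rotary_dim, int t) {
    for (int i = 0; i + 1 < rotary_dim; i += 2) {
        float inv_freq = std::pow(10000.0f, -float(i) / float(rotary_dim));
        float c = std::cos(t * inv_freq);
        float s = std::sin(t * inv_freq);
        float qx = q[i], qy = q[i + 1];
        float kx = k[i], ky = k[i + 1];
        q[i] = qx * c - qy * s;  q[i + 1] = qx * s + qy * c;
        k[i] = kx * c - ky * s;  k[i + 1] = kx * s + ky * c;
    }
}

int main() {
    const int head_dim = 8, rotary_dim = 8;
    std::vector<float> q(head_dim, 1.0f), k(head_dim, 1.0f);
    // With padded batches, the position of the current token for one sequence is its
    // own length (tlength) minus its padding, not a batch-wide timestep.
    int tlength = 5, padd_len = 2;
    apply_rotary(q, k, rotary_dim, tlength - padd_len);
    printf("q[0]=%f k[0]=%f\n", q[0], k[0]);
    return 0;
}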