Commit 000b67f5 authored by Tri Dao's avatar Tri Dao
Browse files

Use int64_t instead of uint32_t for index_t

parent e43a4cea
......@@ -22,7 +22,7 @@ constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = uint32_t;
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
......@@ -99,7 +99,7 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache.
int *__restrict__ cache_batch_idx;
int * __restrict__ cache_batch_idx;
// The dropout probability (probability of keeping an activation).
float p_dropout;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment