Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
0da7b5db
Commit
0da7b5db
authored
Dec 03, 2025
by
PanZezhong
Browse files
issue/97 Attention 和 KVCache 支持 batch 维度
parent
a1f6e517
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
115 additions
and
152 deletions
+115
-152
csrc/cache/kv_cache.hpp
csrc/cache/kv_cache.hpp
+71
-43
csrc/models/llama/llama_attention.cpp
csrc/models/llama/llama_attention.cpp
+44
-109
No files found.
csrc/cache/kv_cache.hpp
View file @
0da7b5db
#pragma once
#pragma once
#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/tensor.hpp"
#include <algorithm>
#include <algorithm>
#include <memory>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <utility>
namespace
infinilm
::
cache
{
namespace
infinilm
::
cache
{
...
@@ -11,26 +15,20 @@ namespace infinilm::cache {
...
@@ -11,26 +15,20 @@ namespace infinilm::cache {
/**
/**
* @brief Simple KV cache structure for incremental decoding
* @brief Simple KV cache structure for incremental decoding
*
*
* Stores key and value caches with shape [n_kv_head, capacity, head_dim]
* Stores key and value caches with shape [
batch_size,
n_kv_head, capacity, head_dim]
* Similar to DynamicLayer in Python cache_utils.py
* Similar to DynamicLayer in Python cache_utils.py
*
*
* This is a common component that can be used by any model architecture
* This is a common component that can be used by any model architecture
* that needs KV caching for attention mechanisms.
* that needs KV caching for attention mechanisms.
*/
*/
struct
KVCache
{
struct
KVCache
{
infinicore
::
Tensor
k_cache
;
// [n_kv_head, capacity, head_dim]
infinicore
::
Tensor
k_cache
;
// [batch_size, n_kv_head, capacity, head_dim]
infinicore
::
Tensor
v_cache
;
// [n_kv_head, capacity, head_dim]
infinicore
::
Tensor
v_cache
;
// [batch_size, n_kv_head, capacity, head_dim]
size_t
cache_position
;
// Current position in cache
std
::
vector
<
size_t
>
cache_positions
;
// Current position in cache
size_t
max_capacity
;
// Maximum capacity of cache
size_t
max_capacity
;
// Maximum capacity of cache
bool
initialized
;
// Whether cache has been initialized
bool
initialized
;
// Whether cache has been initialized
KVCache
()
KVCache
()
:
max_capacity
(
0
),
initialized
(
false
)
{}
:
cache_position
(
0
),
max_capacity
(
0
),
initialized
(
false
),
// Create empty placeholder tensors (will be replaced on first use)
k_cache
(
infinicore
::
Tensor
::
empty
({
1
,
1
,
1
},
infinicore
::
DataType
::
F32
,
infinicore
::
Device
(
infinicore
::
Device
::
Type
::
CPU
,
0
))),
v_cache
(
infinicore
::
Tensor
::
empty
({
1
,
1
,
1
},
infinicore
::
DataType
::
F32
,
infinicore
::
Device
(
infinicore
::
Device
::
Type
::
CPU
,
0
)))
{}
/**
/**
* @brief Initialize or update cache capacity
* @brief Initialize or update cache capacity
...
@@ -40,34 +38,44 @@ struct KVCache {
...
@@ -40,34 +38,44 @@ struct KVCache {
* @param dtype Data type
* @param dtype Data type
* @param device Device
* @param device Device
*/
*/
void
ensure_capacity
(
size_t
num_kv_heads
,
size_t
head_dim
,
size_t
seq_len
,
void
ensure_capacity
(
size_t
batch_size
,
size_t
num_kv_heads
,
size_t
head_dim
,
size_t
seq_len
,
infinicore
::
DataType
dtype
,
const
infinicore
::
Device
&
device
)
{
infinicore
::
DataType
dtype
,
const
infinicore
::
Device
&
device
)
{
size_t
required_capacity
=
cache_position
+
seq_len
;
size_t
required_capacity
=
seq_len
+
std
::
accumulate
(
cache_positions
.
begin
(),
cache_positions
.
end
(),
0
,
[](
int
a
,
int
b
)
{
return
std
::
max
(
a
,
b
);
})
;
// Lazy initialization
// Lazy initialization
if
(
!
initialized
)
{
if
(
!
initialized
)
{
max_capacity
=
std
::
max
(
required_capacity
,
size_t
(
4096
));
// Start with at least 4096
max_capacity
=
std
::
max
(
required_capacity
,
size_t
(
4096
));
// Start with at least 4096
k_cache
=
infinicore
::
Tensor
::
empty
({
num_kv_heads
,
max_capacity
,
head_dim
},
k_cache
=
infinicore
::
Tensor
::
empty
({
batch_size
,
num_kv_heads
,
max_capacity
,
head_dim
},
dtype
,
device
);
dtype
,
device
);
v_cache
=
infinicore
::
Tensor
::
empty
({
num_kv_heads
,
max_capacity
,
head_dim
},
v_cache
=
infinicore
::
Tensor
::
empty
({
batch_size
,
num_kv_heads
,
max_capacity
,
head_dim
},
dtype
,
device
);
dtype
,
device
);
cache_position
=
0
;
cache_position
s
=
std
::
vector
<
size_t
>
(
batch_size
,
0
)
;
initialized
=
true
;
initialized
=
true
;
}
}
// Grow cache if needed (similar to DynamicLayer in Python)
// Grow cache if needed (similar to DynamicLayer in Python)
else
if
(
required_capacity
>
max_capacity
)
{
else
if
(
required_capacity
>
max_capacity
)
{
size_t
new_capacity
=
std
::
max
(
max_capacity
*
2
,
required_capacity
);
size_t
new_capacity
=
std
::
max
(
max_capacity
*
2
,
required_capacity
+
max_capacity
);
auto
k_new
=
infinicore
::
Tensor
::
empty
({
num_kv_heads
,
new_capacity
,
head_dim
},
size_t
new_batch_size
=
std
::
max
(
batch_size
,
k_cache
->
shape
()[
0
]);
if
(
num_kv_heads
!=
k_cache
->
shape
()[
1
]
||
head_dim
!=
k_cache
->
shape
()[
3
])
{
throw
std
::
runtime_error
(
"KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache."
);
}
if
(
new_batch_size
>
cache_positions
.
size
())
{
cache_positions
.
resize
(
new_batch_size
,
0
);
}
auto
k_new
=
infinicore
::
Tensor
::
empty
({
new_batch_size
,
num_kv_heads
,
new_capacity
,
head_dim
},
dtype
,
device
);
dtype
,
device
);
auto
v_new
=
infinicore
::
Tensor
::
empty
({
num_kv_heads
,
new_capacity
,
head_dim
},
auto
v_new
=
infinicore
::
Tensor
::
empty
({
new_batch_size
,
num_kv_heads
,
new_capacity
,
head_dim
},
dtype
,
device
);
dtype
,
device
);
// Copy existing cache data
// Copy existing cache data
if
(
cache_position
>
0
)
{
for
(
size_t
b
=
0
;
b
<
new_batch_size
;
++
b
)
{
auto
k_slice
=
k_cache
->
narrow
({{
1
,
0
,
cache_position
}});
size_t
cache_position
=
cache_positions
[
b
];
auto
v_slice
=
v_cache
->
narrow
({{
1
,
0
,
cache_position
}});
if
(
cache_position
>
0
)
{
k_new
->
narrow
({{
1
,
0
,
cache_position
}})
->
copy_from
(
k_slice
);
auto
k_slice
=
k_cache
->
narrow
({{
0
,
b
,
1
},
{
2
,
0
,
cache_position
}});
v_new
->
narrow
({{
1
,
0
,
cache_position
}})
->
copy_from
(
v_slice
);
auto
v_slice
=
v_cache
->
narrow
({{
0
,
b
,
1
},
{
2
,
0
,
cache_position
}});
k_new
->
narrow
({{
0
,
b
,
1
},
{
2
,
0
,
cache_position
}})
->
copy_from
(
k_slice
);
v_new
->
narrow
({{
0
,
b
,
1
},
{
2
,
0
,
cache_position
}})
->
copy_from
(
v_slice
);
}
}
}
k_cache
=
k_new
;
k_cache
=
k_new
;
...
@@ -76,10 +84,16 @@ struct KVCache {
...
@@ -76,10 +84,16 @@ struct KVCache {
}
}
}
}
/**
 * @brief Construct a cache pre-sized for a given batch and sequence length.
 *
 * Eagerly allocates the K/V buffers via ensure_capacity so the first
 * update() does not pay the allocation cost.
 *
 * @param max_batch_size Number of sequences the cache holds (batch dim)
 * @param n_kv_head      Number of key/value heads
 * @param head_dim       Per-head embedding dimension
 * @param dtype          Element data type of the cache buffers
 * @param max_seqlen     Initial sequence capacity (default 4096)
 * @param device         Device to allocate on (defaults to the current
 *                       context device)
 */
KVCache(size_t max_batch_size,
        size_t n_kv_head,
        size_t head_dim,
        infinicore::DataType dtype,
        size_t max_seqlen = 4096,
        infinicore::Device device = infinicore::context::getDevice())
    : max_capacity(max_seqlen),
      initialized(false) {
    // Every sequence starts writing at offset 0.
    cache_positions.assign(max_batch_size, 0);
    // Allocate the backing tensors up front at the requested capacity.
    ensure_capacity(max_batch_size, n_kv_head, head_dim, max_capacity, dtype, device);
}
/**
/**
* @brief Update cache with new key and value states
* @brief Update cache with new key and value states
* @param k_new New key states [n_kv_head, seq_len, head_dim]
* @param k_new New key states [
batch_size,
n_kv_head, seq_len, head_dim]
* @param v_new New value states [n_kv_head, seq_len, head_dim]
* @param v_new New value states [
batch_size,
n_kv_head, seq_len, head_dim]
* @return Tuple of (k_total, v_total) with shape [n_kv_head, total_seq_len, head_dim]
* @return Tuple of (k_total, v_total) with shape [n_kv_head, total_seq_len, head_dim]
*
*
* Note: This method writes to the cache. If using with attention op, the attention op
* Note: This method writes to the cache. If using with attention op, the attention op
...
@@ -88,28 +102,42 @@ struct KVCache {
...
@@ -88,28 +102,42 @@ struct KVCache {
std
::
pair
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
update
(
std
::
pair
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
update
(
const
infinicore
::
Tensor
&
k_new
,
const
infinicore
::
Tensor
&
k_new
,
const
infinicore
::
Tensor
&
v_new
)
{
const
infinicore
::
Tensor
&
v_new
)
{
size_t
seq_len
=
k_new
->
shape
()[
1
];
if
(
k_new
->
ndim
()
!=
4
||
v_new
->
ndim
()
!=
4
)
{
size_t
num_kv_heads
=
k_new
->
shape
()[
0
];
throw
std
::
runtime_error
(
"KVCache update: k_new and v_new must be 4D tensors in [batch_size, n_kv_head, seq_len, head_dim] form."
);
size_t
head_dim
=
k_new
->
shape
()[
2
];
}
size_t
batch_size
=
k_new
->
shape
()[
0
];
size_t
num_kv_heads
=
k_new
->
shape
()[
1
];
size_t
seq_len
=
k_new
->
shape
()[
2
];
size_t
head_dim
=
k_new
->
shape
()[
3
];
// Ensure capacity
// Ensure capacity
ensure_capacity
(
num_kv_heads
,
head_dim
,
seq_len
,
ensure_capacity
(
batch_size
,
num_kv_heads
,
head_dim
,
seq_len
,
k_new
->
dtype
(),
k_new
->
device
());
k_new
->
dtype
(),
k_new
->
device
());
// Copy new k/v into cache at current position
// Copy new k/v into cache at current position
auto
k_dst
=
k_cache
->
narrow
({{
1
,
cache_position
,
seq_len
}});
bool
all_equal
=
cache_positions
.
empty
()
||
std
::
equal
(
cache_positions
.
begin
()
+
1
,
cache_positions
.
end
(),
cache_positions
.
begin
());
auto
v_dst
=
v_cache
->
narrow
({{
1
,
cache_position
,
seq_len
}});
if
(
all_equal
)
{
k_dst
->
copy_from
(
k_new
);
auto
cache_position
=
cache_positions
[
0
];
v_dst
->
copy_from
(
v_new
);
// Update position
auto
k_dst
=
k_cache
->
narrow
({{
2
,
cache_position
,
seq_len
}});
cache_position
+=
seq_len
;
auto
v_dst
=
v_cache
->
narrow
({{
2
,
cache_position
,
seq_len
}});
k_dst
->
copy_from
(
k_new
);
v_dst
->
copy_from
(
v_new
);
// Return the total cache up to current position
// Update position
auto
k_total
=
k_cache
->
narrow
({{
1
,
0
,
cache_position
}});
cache_position
+=
seq_len
;
auto
v_total
=
v_cache
->
narrow
({{
1
,
0
,
cache_position
}});
for
(
size_t
b
=
0
;
b
<
batch_size
;
++
b
)
{
cache_positions
[
b
]
=
cache_position
;
}
return
std
::
make_pair
(
k_total
->
contiguous
(),
v_total
->
contiguous
());
// Return the total cache up to current position
auto
k_total
=
k_cache
->
narrow
({{
2
,
0
,
cache_position
}});
auto
v_total
=
v_cache
->
narrow
({{
2
,
0
,
cache_position
}});
return
std
::
make_pair
(
k_total
,
v_total
);
}
else
{
throw
std
::
runtime_error
(
"KVCache update: cache positions must be equal among a batch."
);
}
}
}
};
};
...
...
csrc/models/llama/llama_attention.cpp
View file @
0da7b5db
...
@@ -72,15 +72,9 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
...
@@ -72,15 +72,9 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
throw
std
::
runtime_error
(
"Unexpected position_ids shape"
);
throw
std
::
runtime_error
(
"Unexpected position_ids shape"
);
}
}
// 4. Apply RoPE to full batch - align with Python pattern
// 4. Apply RoPE to full batch
auto
q_for_rope
=
q_reshaped
->
view
({
batch_size
*
seq_len
,
num_attention_heads_
,
head_dim_
});
// Python: x = x.view((bs * seq_len, num_heads, head_dim))
auto
k_for_rope
=
k_reshaped
->
view
({
batch_size
*
seq_len
,
num_key_value_heads_
,
head_dim_
});
// Python asserts: seq_len * x_stride[1] == x_stride[0] (contiguous in dim=0 and dim=1)
// The kernel requires stride(2) == 1 (last dimension contiguous)
// Python's assertion + stride(2) == 1 means the tensor is fully contiguous
// However, to be safe and match Python's behavior exactly, ensure fully contiguous
auto
q_for_rope
=
q_reshaped
->
view
({
batch_size
*
seq_len
,
num_attention_heads_
,
head_dim_
})
->
contiguous
();
auto
k_for_rope
=
k_reshaped
->
view
({
batch_size
*
seq_len
,
num_key_value_heads_
,
head_dim_
})
->
contiguous
();
// Call RoPE on full batch (matching Python pattern)
// Call RoPE on full batch (matching Python pattern)
auto
q_rope_out
=
rotary_emb_
->
forward
(
q_for_rope
,
pos_ids_for_rope
);
auto
q_rope_out
=
rotary_emb_
->
forward
(
q_for_rope
,
pos_ids_for_rope
);
...
@@ -92,108 +86,49 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
...
@@ -92,108 +86,49 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
// 5. Process each batch item separately for attention computation
// 5. Process each batch item separately for attention computation
infinilm
::
cache
::
KVCache
*
external_cache
=
static_cast
<
infinilm
::
cache
::
KVCache
*>
(
kv_cache
);
infinilm
::
cache
::
KVCache
*
external_cache
=
static_cast
<
infinilm
::
cache
::
KVCache
*>
(
kv_cache
);
auto
output_tensor
=
infinicore
::
Tensor
::
empty
(
{
batch_size
,
seq_len
,
hidden_size_
},
// Convert to [batch, n_head, seq_len, head_dim] for cache
q
->
dtype
(),
// Ensure contiguous after permute for F16 compatibility with cache operations
q
->
device
());
auto
q_rope
=
q_rope_out
->
permute
({
0
,
2
,
1
,
3
})
->
contiguous
();
// [bs, n_q_head, seq_len, head_dim]
auto
k_rope
=
k_rope_out
->
permute
({
0
,
2
,
1
,
3
});
// [bs, n_kv_head, seq_len, head_dim]
for
(
size_t
b
=
0
;
b
<
batch_size
;
++
b
)
{
auto
v_permuted
=
v_reshaped
->
permute
({
0
,
2
,
1
,
3
});
// [bs, n_kv_head, seq_len, head_dim]
// Extract batch item from RoPE output (already computed above for full batch)
// Ensure contiguous after narrow+view to avoid stride issues in GEMM operations
// 5. Prepare KV caches
auto
q_batch
=
q_rope_out
->
narrow
({{
0
,
b
,
1
}})
->
view
({
seq_len
,
num_attention_heads_
,
head_dim_
});
infinicore
::
Tensor
k_total
;
// [bs, n_kv_head, total_seq_len, head_dim]
auto
k_batch
=
k_rope_out
->
narrow
({{
0
,
b
,
1
}})
->
view
({
seq_len
,
num_key_value_heads_
,
head_dim_
});
infinicore
::
Tensor
v_total
;
// [bs, n_kv_head, total_seq_len, head_dim]
auto
v_batch
=
v_reshaped
->
narrow
({{
0
,
b
,
1
}})
->
view
({
seq_len
,
num_key_value_heads_
,
head_dim_
});
if
(
external_cache
!=
nullptr
)
{
auto
[
k_total_tmp
,
v_total_tmp
]
=
external_cache
->
update
(
k_rope
,
v_permuted
);
// Convert to [n_head, seq_len, head_dim] for cache
k_total
=
k_total_tmp
;
// Ensure contiguous after permute for F16 compatibility with cache operations
v_total
=
v_total_tmp
;
auto
q_rope
=
q_batch
->
permute
({
1
,
0
,
2
})
->
contiguous
();
// [n_q_head, seq_len, head_dim]
}
else
{
auto
k_rope
=
k_batch
->
permute
({
1
,
0
,
2
})
->
contiguous
();
// [n_kv_head, seq_len, head_dim]
auto
[
k_total_tmp
,
v_total_tmp
]
=
internal_cache_
.
update
(
k_rope
,
v_permuted
);
auto
v_permuted
=
v_batch
->
permute
({
1
,
0
,
2
})
->
contiguous
();
// [n_kv_head, seq_len, head_dim]
k_total
=
k_total_tmp
;
v_total
=
v_total_tmp
;
// 5. Prepare KV caches
infinicore
::
Tensor
k_total
;
infinicore
::
Tensor
v_total
;
if
(
external_cache
!=
nullptr
)
{
auto
[
k_total_tmp
,
v_total_tmp
]
=
external_cache
->
update
(
k_rope
,
v_permuted
);
k_total
=
k_total_tmp
;
v_total
=
v_total_tmp
;
}
else
{
auto
[
k_total_tmp
,
v_total_tmp
]
=
internal_cache_
.
update
(
k_rope
,
v_permuted
);
k_total
=
k_total_tmp
;
v_total
=
v_total_tmp
;
}
// 6. Compute attention - strictly align with Python pattern
// Python: query_states_i = query_states.narrow(0, i, 1).view((seq_len, num_attention_heads, head_dim))
// Python: key_states_i = key_states_total.narrow(0, i, 1).view((total_seq_len, num_key_value_heads, head_dim))
// Python: value_states_i = value_states_total.narrow(0, i, 1).view((total_seq_len, num_key_value_heads, head_dim))
// Python: attention_i = grouped_query_attention(query_states_i, key_states_i, value_states_i, scaling=self.scaling)
// Extract from KV cache (k_total and v_total are [n_kv_head, total_seq_len, head_dim])
// Python: key_states_total.narrow(0, i, 1).view((total_seq_len, num_key_value_heads, head_dim))
// Python's narrow+view ensures contiguous memory, so we need to ensure contiguous before permute
auto
k_for_attn
=
k_total
->
permute
({
1
,
0
,
2
});
// [total_seq_len, n_kv_head, head_dim]
auto
v_for_attn
=
v_total
->
permute
({
1
,
0
,
2
});
// [total_seq_len, n_kv_head, head_dim]
// q_batch is already [seq_len, n_q_head, head_dim] from above
auto
q_for_attn
=
q_batch
;
// [seq_len, n_q_head, head_dim]
// Python: grouped_query_attention calls repeat_kv if ngroup > 1
// Python: repeat_kv expands [total_seq_len, num_key_value_heads, head_dim] -> [total_seq_len, num_attention_heads, head_dim]
size_t
ngroup
=
num_attention_heads_
/
num_key_value_heads_
;
if
(
ngroup
>
1
)
{
// Python: repeat_kv uses as_strided to expand
size_t
total_seq_len
=
k_for_attn
->
shape
()[
0
];
size_t
n_kv_head
=
k_for_attn
->
shape
()[
1
];
size_t
head_dim
=
k_for_attn
->
shape
()[
2
];
auto
k_strides
=
k_for_attn
->
strides
();
auto
k_strided
=
k_for_attn
->
as_strided
(
{
total_seq_len
,
n_kv_head
,
ngroup
,
head_dim
},
{
k_strides
[
0
],
k_strides
[
1
],
0
,
k_strides
[
2
]});
k_for_attn
=
k_strided
->
contiguous
()
->
view
({
total_seq_len
,
n_kv_head
*
ngroup
,
head_dim
});
auto
v_strides
=
v_for_attn
->
strides
();
auto
v_strided
=
v_for_attn
->
as_strided
(
{
total_seq_len
,
n_kv_head
,
ngroup
,
head_dim
},
{
v_strides
[
0
],
v_strides
[
1
],
0
,
v_strides
[
2
]});
v_for_attn
=
v_strided
->
contiguous
()
->
view
({
total_seq_len
,
n_kv_head
*
ngroup
,
head_dim
});
}
// Python: multi_head_attention(querys, keys, values, scaling)
// Python: Q = querys.permute((1, 0, 2)) # [num_heads, seq_len, head_dim]
// Python: K = keys # [total_seq_len, num_heads, head_dim] (NO permute!)
// Python: V = values.permute((1, 0, 2)) # [num_heads, total_seq_len, head_dim]
auto
Q
=
q_for_attn
->
permute
({
1
,
0
,
2
});
// [n_q_head, seq_len, head_dim]
auto
K
=
k_for_attn
;
// [total_seq_len, n_q_head, head_dim] - keep as-is (matching Python)
auto
V
=
v_for_attn
->
permute
({
1
,
0
,
2
});
// [n_q_head, total_seq_len, head_dim]
// Python: attn_weight = Q @ K.permute((1, 2, 0))
// Python: K.permute((1, 2, 0)) transforms [total_seq_len, num_heads, head_dim] -> [num_heads, head_dim, total_seq_len]
auto
K_transposed
=
K
->
permute
({
1
,
2
,
0
});
// [n_q_head, head_dim, total_seq_len]
// Use GEMM with alpha=scaling to combine scaling with matrix multiplication
// This is more efficient than doing matmul followed by mul
float
scaling
=
1.0
f
/
std
::
sqrt
(
static_cast
<
float
>
(
head_dim_
));
auto
attn_weight
=
infinicore
::
op
::
matmul
(
Q
,
K_transposed
,
scaling
);
// [n_q_head, seq_len, total_seq_len]
infinicore
::
op
::
causal_softmax_
(
attn_weight
,
attn_weight
);
auto
out
=
infinicore
::
op
::
matmul
(
attn_weight
,
V
);
// [n_q_head, seq_len, head_dim]
// Python: return out.permute((1, 0, 2)).contiguous() # [seq_len, num_heads, head_dim]
auto
attn_output
=
out
->
permute
({
1
,
0
,
2
})
->
contiguous
();
// [seq_len, n_q_head, head_dim]
// Python: attn_output_i.copy_(attention_i)
// Python: attn_output = attn_output.view(hidden_states_shape) # [bs, seq_len, hidden_size]
// Copy to output tensor - attn_output is [seq_len, num_attention_heads, head_dim]
auto
output_batch
=
output_tensor
->
narrow
({{
0
,
b
,
1
}})
->
view
({
seq_len
,
hidden_size_
});
auto
attn_flat
=
attn_output
->
contiguous
()
->
view
({
seq_len
,
hidden_size_
});
output_batch
->
copy_from
(
attn_flat
);
}
}
auto
total_seq_len
=
k_total
->
shape
()[
2
];
// 6. Compute attention
size_t
ngroup
=
num_attention_heads_
/
num_key_value_heads_
;
auto
Q
=
q_rope
->
view
({
batch_size
*
num_key_value_heads_
,
ngroup
*
seq_len
,
head_dim_
});
auto
K
=
k_total
->
view
({
batch_size
*
num_key_value_heads_
,
total_seq_len
,
head_dim_
});
auto
V
=
v_total
->
view
({
batch_size
*
num_key_value_heads_
,
total_seq_len
,
head_dim_
});
auto
K_transposed
=
K
->
permute
({
0
,
2
,
1
});
// [bs * n_kv_head, head_dim, total_seq_len]
float
scaling
=
1.0
f
/
std
::
sqrt
(
static_cast
<
float
>
(
head_dim_
));
auto
attn_weight
=
infinicore
::
op
::
matmul
(
Q
,
K_transposed
,
scaling
);
// [bs * n_kv_head, ng * seq_len, total_seq_len]
auto
attn_weight_softmax
=
attn_weight
->
view
({
batch_size
*
num_attention_heads_
,
seq_len
,
total_seq_len
});
infinicore
::
op
::
causal_softmax_
(
attn_weight_softmax
,
attn_weight_softmax
);
auto
out
=
infinicore
::
op
::
matmul
(
attn_weight
,
V
);
// [bs * n_kv_head, ng * seq_len, head_dim]
auto
attn_output
=
out
->
view
({
batch_size
,
num_attention_heads_
,
seq_len
,
head_dim_
})
->
permute
({
0
,
2
,
1
,
3
})
->
contiguous
()
->
view
({
batch_size
,
seq_len
,
num_attention_heads_
*
head_dim_
});
// [bs, seq_len, n_q_head * head_dim]
// 8. Apply output projection to all batches
auto
output
=
o_proj_
->
forward
(
attn_output
);
auto
output
=
o_proj_
->
forward
(
output_tensor
);
return
output
;
return
output
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment