jerrrrry / infinilm · Commits

Commit 36f8eab7 (unverified)
Authored Dec 04, 2025 by PanZezhong1725; committed by GitHub, Dec 04, 2025

Merge pull request #98 from InfiniTensor/issue/97

issue/97: Attention and KVCache support a batch dimension

Parents: 42f9d47d, 6f624c94

Changes: 4 files, +161 additions, −177 deletions
csrc/cache/kv_cache.hpp                  +73   −43
csrc/models/llama/llama_attention.cpp    +46  −117
examples/llama.py                        +29   −11
python/infinilm/generation/utils.py      +13    −6
csrc/cache/kv_cache.hpp  (view file @ 36f8eab7)
```diff
 #pragma once
+#include "infinicore/context/context.hpp"
 #include "infinicore/device.hpp"
 #include "infinicore/tensor.hpp"
 #include <algorithm>
 #include <memory>
+#include <numeric>
+#include <stdexcept>
 #include <utility>
+#include <spdlog/spdlog.h>
 
 namespace infinilm::cache {
 
 /**
  * @brief Simple KV cache structure for incremental decoding
  *
- * Stores key and value caches with shape [n_kv_head, capacity, head_dim]
+ * Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim]
  * Similar to DynamicLayer in Python cache_utils.py
  *
  * This is a common component that can be used by any model architecture
  * that needs KV caching for attention mechanisms.
  */
 struct KVCache {
-    infinicore::Tensor k_cache;          // [n_kv_head, capacity, head_dim]
-    infinicore::Tensor v_cache;          // [n_kv_head, capacity, head_dim]
-    size_t cache_position;               // Current position in cache
+    infinicore::Tensor k_cache;          // [batch_size, n_kv_head, capacity, head_dim]
+    infinicore::Tensor v_cache;          // [batch_size, n_kv_head, capacity, head_dim]
+    std::vector<size_t> cache_positions; // Current position in cache
     size_t max_capacity;                 // Maximum capacity of cache
     bool initialized;                    // Whether cache has been initialized
 
     KVCache()
-        : max_capacity(0), initialized(false) {}
+        : cache_positions(0), max_capacity(0), initialized(false),
+          // Create empty placeholder tensors (will be replaced on first use)
+          k_cache(infinicore::Tensor::empty({1, 1, 1}, infinicore::DataType::F32,
+                                            infinicore::Device(infinicore::Device::Type::CPU, 0))),
+          v_cache(infinicore::Tensor::empty({1, 1, 1}, infinicore::DataType::F32,
+                                            infinicore::Device(infinicore::Device::Type::CPU, 0))) {}
```
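
For intuition about the new layout, here is a toy NumPy sketch (a hypothetical stand-in, not the `infinicore` API) of the same bookkeeping: a lazily allocated `[batch_size, n_kv_head, capacity, head_dim]` cache with one write position per batch row. It mirrors the semantics of `ensure_capacity` and `update` below, except that the grow step copies the whole buffer rather than only the per-row prefixes.

```python
import numpy as np

class ToyKVCache:
    """Toy mirror of KVCache semantics: [batch_size, n_kv_head, capacity, head_dim]."""

    def __init__(self):
        self.k = self.v = None   # allocated lazily on first update
        self.positions = []      # one write position per batch row
        self.capacity = 0

    def ensure_capacity(self, bs, n_kv_head, head_dim, seq_len, dtype):
        # required = new tokens + furthest position already written
        required = seq_len + (max(self.positions) if self.positions else 0)
        if self.k is None:                       # lazy initialization
            self.capacity = max(required, 4096)  # start with at least 4096
            shape = (bs, n_kv_head, self.capacity, head_dim)
            self.k, self.v = np.empty(shape, dtype), np.empty(shape, dtype)
            self.positions = [0] * bs
        elif required > self.capacity:           # grow, keeping cached steps
            new_cap = max(self.capacity * 2, required + self.capacity)
            grown = []
            for old in (self.k, self.v):
                new = np.empty(old.shape[:2] + (new_cap,) + old.shape[3:], dtype)
                new[:, :, : self.capacity] = old
                grown.append(new)
            self.k, self.v = grown
            self.capacity = new_cap

    def update(self, k_new, v_new):
        bs, n_kv_head, seq_len, head_dim = k_new.shape
        self.ensure_capacity(bs, n_kv_head, head_dim, seq_len, k_new.dtype)
        assert len(set(self.positions)) == 1, "positions must match within a batch"
        pos = self.positions[0]
        self.k[:, :, pos : pos + seq_len] = k_new
        self.v[:, :, pos : pos + seq_len] = v_new
        self.positions = [pos + seq_len] * bs
        return self.k[:, :, : pos + seq_len], self.v[:, :, : pos + seq_len]
```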
```diff
     /**
      * @brief Initialize or update cache capacity
@@ -40,34 +40,44 @@ struct KVCache {
      * @param dtype Data type
      * @param device Device
      */
-    void ensure_capacity(size_t num_kv_heads, size_t head_dim, size_t seq_len,
+    void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len,
                          infinicore::DataType dtype, const infinicore::Device &device) {
-        size_t required_capacity = cache_position + seq_len;
+        size_t required_capacity = seq_len + std::accumulate(
+            cache_positions.begin(), cache_positions.end(), 0,
+            [](int a, int b) { return std::max(a, b); });
 
         // Lazy initialization
         if (!initialized) {
             max_capacity = std::max(required_capacity, size_t(4096)); // Start with at least 4096
-            k_cache = infinicore::Tensor::empty({num_kv_heads, max_capacity, head_dim},
+            k_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
                                                 dtype, device);
-            v_cache = infinicore::Tensor::empty({num_kv_heads, max_capacity, head_dim},
+            v_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
                                                 dtype, device);
-            cache_position = 0;
+            cache_positions = std::vector<size_t>(batch_size, 0);
             initialized = true;
         }
         // Grow cache if needed (similar to DynamicLayer in Python)
         else if (required_capacity > max_capacity) {
-            size_t new_capacity = std::max(max_capacity * 2, required_capacity);
+            size_t new_capacity = std::max(max_capacity * 2, required_capacity + max_capacity);
+            size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]);
+            if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) {
+                throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache.");
+            }
+            if (new_batch_size > cache_positions.size()) {
+                cache_positions.resize(new_batch_size, 0);
+            }
-            auto k_new = infinicore::Tensor::empty({num_kv_heads, new_capacity, head_dim},
+            auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
                                                    dtype, device);
-            auto v_new = infinicore::Tensor::empty({num_kv_heads, new_capacity, head_dim},
+            auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
                                                    dtype, device);
 
             // Copy existing cache data
-            if (cache_position > 0) {
-                auto k_slice = k_cache->narrow({{1, 0, cache_position}});
-                auto v_slice = v_cache->narrow({{1, 0, cache_position}});
-                k_new->narrow({{1, 0, cache_position}})->copy_from(k_slice);
-                v_new->narrow({{1, 0, cache_position}})->copy_from(v_slice);
+            for (size_t b = 0; b < new_batch_size; ++b) {
+                size_t cache_position = cache_positions[b];
+                if (cache_position > 0) {
+                    auto k_slice = k_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
+                    auto v_slice = v_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
+                    k_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(k_slice);
+                    v_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(v_slice);
+                }
             }
             k_cache = k_new;
@@ -76,10 +86,16 @@ struct KVCache {
         }
     }
```
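
The growth rule also changed, from `max(2 * cap, required)` to `max(2 * cap, required + cap)`, which guarantees at least `cap` free slots beyond the immediate requirement instead of possibly reallocating to an exact fit. A quick check of the arithmetic:

```python
cap = 4096
required = 9000                              # e.g. a long prompt after several turns
old_rule = max(2 * cap, required)            # 9000: exactly fits, no headroom left
new_rule = max(2 * cap, required + cap)      # 13096: always >= cap slots of headroom
print(old_rule, new_rule)
```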
```diff
+    KVCache(size_t max_batch_size, size_t n_kv_head, size_t head_dim, infinicore::DataType dtype,
+            size_t max_seqlen = 4096, infinicore::Device device = infinicore::context::getDevice())
+        : max_capacity(max_seqlen), initialized(false) {
+        cache_positions = std::vector<size_t>(max_batch_size, 0);
+        ensure_capacity(max_batch_size, n_kv_head, head_dim, max_capacity, dtype, device);
+    }
+
     /**
      * @brief Update cache with new key and value states
-     * @param k_new New key states [n_kv_head, seq_len, head_dim]
-     * @param v_new New value states [n_kv_head, seq_len, head_dim]
+     * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
+     * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
      * @return Tuple of (k_total, v_total) with shape [n_kv_head, total_seq_len, head_dim]
      *
      * Note: This method writes to the cache. If using with attention op, the attention op
@@ -88,28 +104,42 @@ struct KVCache {
     std::pair<infinicore::Tensor, infinicore::Tensor> update(
         const infinicore::Tensor &k_new,
         const infinicore::Tensor &v_new) {
-        size_t seq_len = k_new->shape()[1];
-        size_t num_kv_heads = k_new->shape()[0];
-        size_t head_dim = k_new->shape()[2];
+        if (k_new->ndim() != 4 || v_new->ndim() != 4) {
+            throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors in [batch_size, n_kv_head, seq_len, head_dim] form.");
+        }
+        size_t batch_size = k_new->shape()[0];
+        size_t num_kv_heads = k_new->shape()[1];
+        size_t seq_len = k_new->shape()[2];
+        size_t head_dim = k_new->shape()[3];
 
         // Ensure capacity
-        ensure_capacity(num_kv_heads, head_dim, seq_len, k_new->dtype(), k_new->device());
+        ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len, k_new->dtype(), k_new->device());
 
         // Copy new k/v into cache at current position
-        auto k_dst = k_cache->narrow({{1, cache_position, seq_len}});
-        auto v_dst = v_cache->narrow({{1, cache_position, seq_len}});
-        k_dst->copy_from(k_new);
-        v_dst->copy_from(v_new);
-        // Update position
-        cache_position += seq_len;
-        // Return the total cache up to current position
-        auto k_total = k_cache->narrow({{1, 0, cache_position}});
-        auto v_total = v_cache->narrow({{1, 0, cache_position}});
-        return std::make_pair(k_total->contiguous(), v_total->contiguous());
+        bool all_equal = cache_positions.empty()
+                         || std::equal(cache_positions.begin() + 1, cache_positions.end(),
+                                       cache_positions.begin());
+        if (all_equal) {
+            auto cache_position = cache_positions[0];
+            auto k_dst = k_cache->narrow({{2, cache_position, seq_len}});
+            auto v_dst = v_cache->narrow({{2, cache_position, seq_len}});
+            k_dst->copy_from(k_new);
+            v_dst->copy_from(v_new);
+            // Update position
+            cache_position += seq_len;
+            for (size_t b = 0; b < batch_size; ++b) {
+                cache_positions[b] = cache_position;
+            }
+            // Return the total cache up to current position
+            auto k_total = k_cache->narrow({{2, 0, cache_position}});
+            auto v_total = v_cache->narrow({{2, 0, cache_position}});
+            return std::make_pair(k_total, v_total);
+        } else {
+            throw std::runtime_error("KVCache update: cache positions must be equal among a batch.");
+        }
     }
 };
```
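
Under the same toy model from above, an update sequence behaves like the C++ version's prefill-then-decode flow: a single `narrow` along the sequence axis covers every batch row, which is exactly why the positions must be equal within a batch.

```python
cache = ToyKVCache()
k1 = np.random.rand(2, 4, 5, 8).astype(np.float32)   # prefill: 5 tokens per row
k_tot, v_tot = cache.update(k1, k1)
print(k_tot.shape, cache.positions)                  # (2, 4, 5, 8) [5, 5]
k2 = np.random.rand(2, 4, 1, 8).astype(np.float32)   # decode step: 1 token per row
k_tot, v_tot = cache.update(k2, k2)
print(k_tot.shape, cache.positions)                  # (2, 4, 6, 8) [6, 6]
```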
csrc/models/llama/llama_attention.cpp  (view file @ 36f8eab7)
```diff
@@ -72,128 +72,57 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
         throw std::runtime_error("Unexpected position_ids shape");
     }
 
-    // 4. Apply RoPE to full batch - align with Python pattern
-    // Python: x = x.view((bs * seq_len, num_heads, head_dim))
-    // Python asserts: seq_len * x_stride[1] == x_stride[0] (contiguous in dim=0 and dim=1)
-    // The kernel requires stride(2) == 1 (last dimension contiguous)
-    // Python's assertion + stride(2) == 1 means the tensor is fully contiguous
-    // However, to be safe and match Python's behavior exactly, ensure fully contiguous
-    auto q_for_rope = q_reshaped->view({batch_size * seq_len, num_attention_heads_, head_dim_})->contiguous();
-    auto k_for_rope = k_reshaped->view({batch_size * seq_len, num_key_value_heads_, head_dim_})->contiguous();
-    // Call RoPE on full batch (matching Python pattern)
-    auto q_rope_out = rotary_emb_->forward(q_for_rope, pos_ids_for_rope);
-    auto k_rope_out = rotary_emb_->forward(k_for_rope, pos_ids_for_rope);
-    // Reshape back to [batch_size, seq_len, num_heads, head_dim] (matching Python pattern)
-    q_rope_out = q_rope_out->view({batch_size, seq_len, num_attention_heads_, head_dim_});
-    k_rope_out = k_rope_out->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
-    // 5. Process each batch item separately for attention computation
-    infinilm::cache::KVCache *external_cache = static_cast<infinilm::cache::KVCache *>(kv_cache);
-    auto output_tensor = infinicore::Tensor::empty({batch_size, seq_len, hidden_size_},
-                                                   q->dtype(), q->device());
-    for (size_t b = 0; b < batch_size; ++b) {
-        // Extract batch item from RoPE output (already computed above for full batch)
-        // Ensure contiguous after narrow+view to avoid stride issues in GEMM operations
-        auto q_batch = q_rope_out->narrow({{0, b, 1}})->view({seq_len, num_attention_heads_, head_dim_});
-        auto k_batch = k_rope_out->narrow({{0, b, 1}})->view({seq_len, num_key_value_heads_, head_dim_});
-        auto v_batch = v_reshaped->narrow({{0, b, 1}})->view({seq_len, num_key_value_heads_, head_dim_});
-        // Convert to [n_head, seq_len, head_dim] for cache
-        // Ensure contiguous after permute for F16 compatibility with cache operations
-        auto q_rope = q_batch->permute({1, 0, 2})->contiguous();     // [n_q_head, seq_len, head_dim]
-        auto k_rope = k_batch->permute({1, 0, 2})->contiguous();     // [n_kv_head, seq_len, head_dim]
-        auto v_permuted = v_batch->permute({1, 0, 2})->contiguous(); // [n_kv_head, seq_len, head_dim]
-        // 5. Prepare KV caches
-        infinicore::Tensor k_total;
-        infinicore::Tensor v_total;
-        if (external_cache != nullptr) {
-            auto [k_total_tmp, v_total_tmp] = external_cache->update(k_rope, v_permuted);
-            k_total = k_total_tmp;
-            v_total = v_total_tmp;
-        } else {
-            auto [k_total_tmp, v_total_tmp] = internal_cache_.update(k_rope, v_permuted);
-            k_total = k_total_tmp;
-            v_total = v_total_tmp;
-        }
-        // 6. Compute attention - strictly align with Python pattern
-        // Python: query_states_i = query_states.narrow(0, i, 1).view((seq_len, num_attention_heads, head_dim))
-        // Python: key_states_i = key_states_total.narrow(0, i, 1).view((total_seq_len, num_key_value_heads, head_dim))
-        // Python: value_states_i = value_states_total.narrow(0, i, 1).view((total_seq_len, num_key_value_heads, head_dim))
-        // Python: attention_i = grouped_query_attention(query_states_i, key_states_i, value_states_i, scaling=self.scaling)
-        // Extract from KV cache (k_total and v_total are [n_kv_head, total_seq_len, head_dim])
-        // Python: key_states_total.narrow(0, i, 1).view((total_seq_len, num_key_value_heads, head_dim))
-        // Python's narrow+view ensures contiguous memory, so we need to ensure contiguous before permute
-        auto k_for_attn = k_total->permute({1, 0, 2}); // [total_seq_len, n_kv_head, head_dim]
-        auto v_for_attn = v_total->permute({1, 0, 2}); // [total_seq_len, n_kv_head, head_dim]
-        // q_batch is already [seq_len, n_q_head, head_dim] from above
-        auto q_for_attn = q_batch; // [seq_len, n_q_head, head_dim]
-        // Python: grouped_query_attention calls repeat_kv if ngroup > 1
-        // Python: repeat_kv expands [total_seq_len, num_key_value_heads, head_dim] -> [total_seq_len, num_attention_heads, head_dim]
-        size_t ngroup = num_attention_heads_ / num_key_value_heads_;
-        if (ngroup > 1) {
-            // Python: repeat_kv uses as_strided to expand
-            size_t total_seq_len = k_for_attn->shape()[0];
-            size_t n_kv_head = k_for_attn->shape()[1];
-            size_t head_dim = k_for_attn->shape()[2];
-            auto k_strides = k_for_attn->strides();
-            auto k_strided = k_for_attn->as_strided({total_seq_len, n_kv_head, ngroup, head_dim},
-                                                    {k_strides[0], k_strides[1], 0, k_strides[2]});
-            k_for_attn = k_strided->contiguous()->view({total_seq_len, n_kv_head * ngroup, head_dim});
-            auto v_strides = v_for_attn->strides();
-            auto v_strided = v_for_attn->as_strided({total_seq_len, n_kv_head, ngroup, head_dim},
-                                                    {v_strides[0], v_strides[1], 0, v_strides[2]});
-            v_for_attn = v_strided->contiguous()->view({total_seq_len, n_kv_head * ngroup, head_dim});
-        }
-        // Python: multi_head_attention(querys, keys, values, scaling)
-        // Python: Q = querys.permute((1, 0, 2)) # [num_heads, seq_len, head_dim]
-        // Python: K = keys # [total_seq_len, num_heads, head_dim] (NO permute!)
-        // Python: V = values.permute((1, 0, 2)) # [num_heads, total_seq_len, head_dim]
-        auto Q = q_for_attn->permute({1, 0, 2}); // [n_q_head, seq_len, head_dim]
-        auto K = k_for_attn;                     // [total_seq_len, n_q_head, head_dim] - keep as-is (matching Python)
-        auto V = v_for_attn->permute({1, 0, 2}); // [n_q_head, total_seq_len, head_dim]
-        // Python: attn_weight = Q @ K.permute((1, 2, 0))
-        // Python: K.permute((1, 2, 0)) transforms [total_seq_len, num_heads, head_dim] -> [num_heads, head_dim, total_seq_len]
-        auto K_transposed = K->permute({1, 2, 0}); // [n_q_head, head_dim, total_seq_len]
-        // Use GEMM with alpha=scaling to combine scaling with matrix multiplication
-        // This is more efficient than doing matmul followed by mul
-        float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
-        auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [n_q_head, seq_len, total_seq_len]
-        infinicore::op::causal_softmax_(attn_weight, attn_weight);
-        auto out = infinicore::op::matmul(attn_weight, V); // [n_q_head, seq_len, head_dim]
-        // Python: return out.permute((1, 0, 2)).contiguous() # [seq_len, num_heads, head_dim]
-        auto attn_output = out->permute({1, 0, 2})->contiguous(); // [seq_len, n_q_head, head_dim]
-        // Python: attn_output_i.copy_(attention_i)
-        // Python: attn_output = attn_output.view(hidden_states_shape) # [bs, seq_len, hidden_size]
-        // Copy to output tensor - attn_output is [seq_len, num_attention_heads, head_dim]
-        auto output_batch = output_tensor->narrow({{0, b, 1}})->view({seq_len, hidden_size_});
-        auto attn_flat = attn_output->contiguous()->view({seq_len, hidden_size_});
-        output_batch->copy_from(attn_flat);
-    }
-    // 8. Apply output projection to all batches
-    auto output = o_proj_->forward(output_tensor);
+    // 4. Process each batch item separately for attention computation
+    infinilm::cache::KVCache *external_cache = static_cast<infinilm::cache::KVCache *>(kv_cache);
+    // Convert to [batch, n_head, seq_len, head_dim] for cache
+    // Ensure contiguous after permute for F16 compatibility with cache operations
+    q_reshaped = q_reshaped->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
+    auto k_permuted = k_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]
+    auto v_permuted = v_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]
+    // 4. Prepare KV caches
+    infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
+    infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
+    if (external_cache != nullptr) {
+        auto [k_total_tmp, v_total_tmp] = external_cache->update(k_permuted, v_permuted);
+        k_total = k_total_tmp;
+        v_total = v_total_tmp;
+    } else {
+        auto [k_total_tmp, v_total_tmp] = internal_cache_.update(k_permuted, v_permuted);
+        k_total = k_total_tmp;
+        v_total = v_total_tmp;
+    }
+    auto total_seq_len = k_total->shape()[2];
+    // 5. Apply RoPE to full batch
+    auto q_rope = q_reshaped->view({batch_size * num_attention_heads_, seq_len, head_dim_})
+                      ->permute({1, 0, 2}); // [seq_len, bs * n_q_head, head_dim]
+    auto k_rope = k_total->narrow({{2, total_seq_len - seq_len, seq_len}})
+                      ->view({batch_size * num_key_value_heads_, seq_len, head_dim_})
+                      ->permute({1, 0, 2}); // [seq_len, bs * n_kv_head, head_dim]
+    rotary_emb_->forward(q_rope, pos_ids_for_rope, true);
+    rotary_emb_->forward(k_rope, pos_ids_for_rope, true);
+    // 6. Compute attention
+    size_t ngroup = num_attention_heads_ / num_key_value_heads_;
+    auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
+    auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
+    auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
+    auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]
+    float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
+    auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [bs * n_kv_head, ng * seq_len, total_seq_len]
+    auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
+    infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
+    auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]
+    auto attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
+                           ->permute({0, 2, 1, 3})
+                           ->contiguous()
+                           ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
+    auto output = o_proj_->forward(attn_output);
 
     return output;
 }
```
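
The batched path removes the per-item loop and the `repeat_kv` copy entirely: because the query heads that share one KV head are adjacent in `[bs, n_q_head, seq_len, head_dim]`, folding Q into `[bs * n_kv_head, ngroup * seq_len, head_dim]` lets a single batched GEMM against K `[bs * n_kv_head, total_seq_len, head_dim]` produce all per-query-head logits, which are then softmaxed through a `[bs * n_q_head, seq_len, total_seq_len]` view of the same buffer. A NumPy sketch (illustrative shapes only, not the `infinicore` API) checking the equivalence:

```python
import numpy as np

bs, n_q, n_kv, s, t, d = 2, 8, 2, 3, 7, 16
ngroup = n_q // n_kv                                   # 4 query heads per KV head
q = np.random.rand(bs, n_q, s, d).astype(np.float32)   # [bs, n_q_head, seq, d]
k = np.random.rand(bs, n_kv, t, d).astype(np.float32)  # [bs, n_kv_head, total, d]

# Reference: repeat each KV head ngroup times, then per-query-head logits
k_rep = np.repeat(k, ngroup, axis=1)                   # [bs, n_q, t, d]
ref = np.einsum("bhsd,bhtd->bhst", q, k_rep)           # [bs, n_q, s, t]

# Folded: view Q as [bs*n_kv, ngroup*s, d], K as [bs*n_kv, t, d], one GEMM
Qf = q.reshape(bs * n_kv, ngroup * s, d)
Kf = k.reshape(bs * n_kv, t, d)
fold = (Qf @ Kf.transpose(0, 2, 1)).reshape(bs, n_q, s, t)

print(np.allclose(ref, fold, atol=1e-5))  # True: no repeat_kv copy needed
```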
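One more subtlety in the new path: RoPE is applied in place (the third `true` argument) through a permuted `[seq_len, bs * n_head, head_dim]` view, so the rotation lands directly in `q_reshaped` and in the tail slice of the cache's `k_total`, and the position ids only need to index the sequence axis. A hedged NumPy analogue of writing through such a view:

```python
import numpy as np

bs, h, s, d = 2, 4, 3, 8
q = np.zeros((bs * h, s, d), dtype=np.float32)  # [bs*n_head, seq, d], contiguous
view = q.transpose(1, 0, 2)                     # [seq, bs*n_head, d], no copy
view[1] += 1.0                                  # "rotate" every head at position 1
print(q[:, 1].sum() == bs * h * d)              # True: the write went through the view
```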
examples/llama.py  (view file @ 36f8eab7)
```diff
@@ -63,11 +63,23 @@ def get_args():
         default="float32",
         help="float32, float16, bfloat16",
     )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="number of prompts in a batch",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="How are you",
+        help="input prompt",
+    )
     return parser.parse_args()
 
 
 def test(
-    prompt,
+    prompts: str | list[str],
     model_path,
     max_new_tokens=100,
     infini_dtype=infinicore.bfloat16,
```
```diff
@@ -123,18 +135,24 @@ def test(
     # ---------------------------------------------------------------------------- #
     # Token encoding
     # ---------------------------------------------------------------------------- #
     # prompt = "山东最高的山是?"
-    input_content = tokenizer.apply_chat_template(
-        conversation=[{"role": "user", "content": prompt}],
-        add_generation_prompt=True,
-        tokenize=False,
-    )
-    print(input_content, end="", flush=True)
-    input_ids = tokenizer.encode(input_content)
+    if isinstance(prompts, str):
+        prompts = [prompts]
+    input_contents = [
+        tokenizer.apply_chat_template(
+            conversation=[{"role": "user", "content": prompt}],
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        for prompt in prompts
+    ]
+    print(input_contents[0], end="", flush=True)
+    input_ids_list = tokenizer.batch_encode_plus(input_contents)["input_ids"]
+    # List: [[1, 1128, 526, 366, 29892]]
     # ---------------------------------------------------------------------------- #
     # Autoregressive generation
     # ---------------------------------------------------------------------------- #
-    input_ids_list = [input_ids]  # List: [[1, 1128, 526, 366, 29892]]
     input_ids_infini = infinicore.from_list(input_ids_list)
     t1 = time.time()
```
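
Note that `infinicore.from_list` presumably expects a rectangular list of lists; that holds here because every batch row is a copy of the same prompt, so `batch_encode_plus` returns equal-length id lists (an assumption about `from_list`, not something this diff states). A quick illustration with hypothetical token ids:

```python
# Hypothetical ids from batch_encode_plus on two identical prompts
input_ids_list = [[1, 1128, 526, 366, 29892],
                  [1, 1128, 526, 366, 29892]]
row_lengths = {len(ids) for ids in input_ids_list}
assert len(row_lengths) == 1, "rows must be equal-length to form a [batch, seq] tensor"
```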
```diff
@@ -175,7 +193,7 @@ if __name__ == "__main__":
             "such as, python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
         )
         sys.exit(1)
-    prompt = "How are you"
+    prompts = [args.prompt for _ in range(args.batch_size)]
     model_path = args.model_path
     max_new_tokens = args.max_new_tokens
@@ -192,7 +210,7 @@ if __name__ == "__main__":
         raise ValueError(f"Unsupported dtype: {args.dtype}")
     test(
-        prompt,
+        prompts,
         model_path,
         max_new_tokens,
         infini_device=infini_device,
```
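
With the two new flags, a batched run can be launched as, for example, `python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0 --batch_size=4 --prompt="How are you"` (an illustrative invocation composed from the flags above; the script simply replicates the single prompt `batch_size` times).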
python/infinilm/generation/utils.py  (view file @ 36f8eab7)
```diff
@@ -100,9 +100,11 @@ class GenerationMixin:
         # -------------------------------------------------------------------- #
         # Required: the tokens' input_ids
         # -------------------------------------------------------------------- #
-        if kwargs.get("next_token_id", None) is not None:
-            next_token_id = kwargs["next_token_id"]
-            model_inputs["input_ids"] = infinicore.from_list([[next_token_id]])
+        if kwargs.get("next_token_ids", None) is not None:
+            next_token_ids = kwargs["next_token_ids"]
+            model_inputs["input_ids"] = infinicore.from_list(
+                [[id_] for id_ in next_token_ids],
+            )
         # -------------------------------------------------------------------- #
         # Others
```
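
With this change each decode step feeds one new token per sequence, i.e. a `[batch_size, 1]` tensor instead of the previous `[1, 1]`. A sketch of the list shape handed to `infinicore.from_list` (toy token values):

```python
next_token_ids = [523, 29892, 366]            # one sampled token per batch row
input_ids = [[id_] for id_ in next_token_ids]
print(input_ids)                              # [[523], [29892], [366]] -> shape [3, 1]
```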
```diff
@@ -236,7 +238,7 @@ class GenerationMixin:
         token_id = next_tokens.to_numpy()[0]
         output_str = tokenizer.decode([token_id], skip_special_tokens=True)
-        model_kwargs["next_token_id"] = token_id
+        model_kwargs["next_token_ids"] = next_tokens.to_numpy().tolist()
         output_tokens_list.append(token_id)
         output_content += output_str
```
```diff
@@ -245,11 +247,16 @@ class GenerationMixin:
             break
 
     print("\n</s>")
-    print(f"\n\n\nGeneration completed in {round(sum(time_list), 2)} ms")
-    print(
-        f"\n\n\nTime per step: prefill {round(time_list[0], 2)} ms/token\n",
-    )
-    print(
-        f"Time per step: decoder {round(sum(time_list[1:]) / (len(time_list) - 1), 2)} ms/token\n",
-    )
+    print(
+        f"\n\n\nGeneration completed in {round(sum(time_list), 2)} ms",
+        f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_list)}\n",
+    )
+    print(
+        f" Prefill TTFT: {round(time_list[0], 2)} ms Throughput: {round((1000 * batch_size * seq_len) / time_list[0], 2)} tok/s\n",
+    )
+    if len(time_list) > 1:
+        print(
+            f" Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)} ms Throughput: {round((1000 * batch_size * (len(time_list) - 1)) / sum(time_list[1:]), 2)} tok/s\n",
+        )
 
     return output_tokens_list, output_content
```
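
The reported numbers follow directly from the per-step timing list: prefill throughput counts all prompt tokens across the batch in the first step, and decode throughput counts one token per sequence for each remaining step. A worked example with assumed values:

```python
batch_size, seq_len = 4, 32
time_list = [80.0, 20.0, 20.0, 20.0]                 # ms per step (assumed)

ttft = time_list[0]                                  # 80.0 ms
prefill_tput = 1000 * batch_size * seq_len / ttft    # 1600.0 tok/s
itl = sum(time_list[1:]) / (len(time_list) - 1)      # 20.0 ms average inter-token latency
decode_tput = 1000 * batch_size * (len(time_list) - 1) / sum(time_list[1:])  # 200.0 tok/s
print(ttft, prefill_tput, itl, decode_tput)
```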