Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
ff00b5c8
Commit
ff00b5c8
authored
Dec 23, 2025
by
PanZezhong
Browse files
issue/125 统一Cache接口
parent
13a4154a
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
305 additions
and
631 deletions
+305
-631
csrc/cache/base_cache.hpp
csrc/cache/base_cache.hpp
+19
-0
csrc/cache/cache.hpp
csrc/cache/cache.hpp
+1
-1
csrc/cache/cache_config.hpp
csrc/cache/cache_config.hpp
+0
-47
csrc/cache/kv_cache.cpp
csrc/cache/kv_cache.cpp
+113
-0
csrc/cache/kv_cache.hpp
csrc/cache/kv_cache.hpp
+54
-338
csrc/engine/distributed/communication_group.cpp
csrc/engine/distributed/communication_group.cpp
+1
-1
csrc/engine/infer_engine.cpp
csrc/engine/infer_engine.cpp
+12
-42
csrc/engine/infer_engine.hpp
csrc/engine/infer_engine.hpp
+9
-9
csrc/engine/rank_worker.cpp
csrc/engine/rank_worker.cpp
+30
-67
csrc/engine/rank_worker.hpp
csrc/engine/rank_worker.hpp
+4
-9
csrc/models/infinilm_model.hpp
csrc/models/infinilm_model.hpp
+4
-7
csrc/models/llama/llama_attention.cpp
csrc/models/llama/llama_attention.cpp
+8
-8
csrc/models/llama/llama_attention.hpp
csrc/models/llama/llama_attention.hpp
+2
-1
csrc/models/llama/llama_decoder_layer.cpp
csrc/models/llama/llama_decoder_layer.cpp
+3
-2
csrc/models/llama/llama_decoder_layer.hpp
csrc/models/llama/llama_decoder_layer.hpp
+2
-1
csrc/models/llama/llama_for_causal_lm.cpp
csrc/models/llama/llama_for_causal_lm.cpp
+4
-8
csrc/models/llama/llama_for_causal_lm.hpp
csrc/models/llama/llama_for_causal_lm.hpp
+1
-3
csrc/models/llama/llama_model.cpp
csrc/models/llama/llama_model.cpp
+21
-48
csrc/models/llama/llama_model.hpp
csrc/models/llama/llama_model.hpp
+8
-31
csrc/models/model_factory.cpp
csrc/models/model_factory.cpp
+9
-8
No files found.
csrc/cache/base_cache.hpp
0 → 100644
View file @
ff00b5c8
#pragma once
#include "../engine/distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
namespace
infinilm
::
cache
{
class
Cache
{
public:
Cache
()
=
default
;
virtual
~
Cache
()
{}
};
class
CacheConfig
{
public:
CacheConfig
()
=
default
;
virtual
~
CacheConfig
()
{}
virtual
std
::
unique_ptr
<
CacheConfig
>
unique_copy
()
const
=
0
;
};
}
// namespace infinilm::cache
csrc/cache/cache.hpp
View file @
ff00b5c8
#pragma once
#pragma once
#include "cache
_config
.hpp"
#include "
base_
cache.hpp"
#include "kv_cache.hpp"
#include "kv_cache.hpp"
csrc/cache/cache_config.hpp
deleted
100644 → 0
View file @
13a4154a
#pragma once
#include <cstddef>
#include <string>
#include <cstdint>
namespace
infinilm
::
cache
{
/**
* @enum CacheType
* @brief Enumeration of supported cache types
*/
enum
class
CacheType
{
DYNAMIC
,
///< Dynamic KV cache (grows as needed)
PAGED
,
///< Paged KV cache (for paged attention)
};
enum
class
CacheResetMode
{
PRESERVE
,
// Keep cache memory, only reset positions
RECREATE
// Recreate cache with new configuration
};
struct
CacheConfig
{
CacheType
type
=
CacheType
::
DYNAMIC
;
size_t
num_layers
=
0
;
size_t
max_kv_cache_length
=
SIZE_MAX
;
size_t
initial_capacity
=
1024
;
// Initial cache capacity in tokens
size_t
initial_batch_size
=
1
;
// Initial batch size for cache allocation
float
growth_factor
=
2.0
f
;
// Cache growth factor when resizing
bool
allow_expand
=
true
;
// Whether to allow cache expansion
CacheResetMode
reset_mode
=
CacheResetMode
::
PRESERVE
;
// Constructor
CacheConfig
()
=
default
;
CacheConfig
(
CacheType
type
,
size_t
num_layers
=
32
,
size_t
max_kv_cache_length
=
4096
)
:
type
(
type
),
num_layers
(
num_layers
),
max_kv_cache_length
(
max_kv_cache_length
)
{}
bool
operator
==
(
const
CacheConfig
&
other
)
const
{
return
type
==
other
.
type
&&
num_layers
==
other
.
num_layers
&&
max_kv_cache_length
==
other
.
max_kv_cache_length
&&
initial_capacity
==
other
.
initial_capacity
&&
initial_batch_size
==
other
.
initial_batch_size
&&
growth_factor
==
other
.
growth_factor
;
}
bool
operator
!=
(
const
CacheConfig
&
other
)
const
{
return
!
(
*
this
==
other
);
}
};
}
// namespace infinilm::cache
csrc/cache/kv_cache.cpp
0 → 100644
View file @
ff00b5c8
#include "kv_cache.hpp"
#include "../utils.hpp"
#include <stdexcept>
namespace
infinilm
::
cache
{
// ==========================
// StaticKVCache
// ==========================
StaticKVCache
::
StaticKVCache
(
infinicore
::
Size
k_dim
,
infinicore
::
Size
v_dim
,
infinicore
::
Size
num_k_heads
,
infinicore
::
Size
num_v_heads
,
infinicore
::
Size
num_layers
,
infinicore
::
Size
max_positional_embedding
,
infinicore
::
DataType
dtype
,
const
StaticKVCacheConfig
&
config
,
const
engine
::
distributed
::
RankInfo
&
rank_info
)
:
Cache
(),
k_dim_
(
k_dim
),
v_dim_
(
v_dim
),
num_rank_k_heads_
(
num_k_heads
/
rank_info
.
tp_size
),
num_rank_v_heads_
(
num_v_heads
/
rank_info
.
tp_size
),
rank_batch_size_
(
config
.
max_batch_size
()),
cache_len_
(
std
::
min
(
config
.
max_cache_len
(),
max_positional_embedding
)),
rank_num_layers_
(
num_layers
),
dtype_
(
dtype
)
{
// Allocate K cache
k_caches_
=
infinicore
::
Tensor
::
empty
(
{
rank_num_layers_
,
rank_batch_size_
,
num_rank_k_heads_
,
cache_len_
,
k_dim_
},
dtype_
,
rank_info
.
device
);
// Allocate V cache
v_caches_
=
infinicore
::
Tensor
::
empty
(
{
rank_num_layers_
,
rank_batch_size_
,
num_rank_v_heads_
,
cache_len_
,
v_dim_
},
dtype_
,
rank_info
.
device
);
spdlog
::
info
(
"Created Static KV Cache: K[{}] V[{}]"
,
k_caches_
->
info
(),
v_caches_
->
info
());
}
std
::
tuple
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
StaticKVCache
::
update
(
size_t
layer_idx
,
const
infinicore
::
Tensor
&
k
,
const
infinicore
::
Tensor
&
v
,
const
infinicore
::
Tensor
&
cache_positions
)
{
ASSERT
(
layer_idx
<
rank_num_layers_
);
auto
batch_size
=
k
->
size
(
0
);
auto
update_len
=
k
->
size
(
2
);
size_t
cache_pos
=
reinterpret_cast
<
int64_t
*>
(
cache_positions
->
to
(
infinicore
::
Device
::
cpu
())
->
data
())[
0
];
auto
result_len
=
cache_pos
+
update_len
;
ASSERT
(
result_len
<=
cache_len_
);
ASSERT_EQ
(
batch_size
,
rank_batch_size_
);
auto
k_cache_layer
=
k_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
auto
v_cache_layer
=
v_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
auto
k_cache_update
=
k_cache_layer
->
narrow
({{
2
,
cache_pos
,
update_len
}});
auto
v_cache_update
=
v_cache_layer
->
narrow
({{
2
,
cache_pos
,
update_len
}});
k_cache_update
->
copy_from
(
k
);
v_cache_update
->
copy_from
(
v
);
auto
k_total
=
k_cache_layer
->
narrow
({{
2
,
0
,
result_len
}});
auto
v_total
=
v_cache_layer
->
narrow
({{
2
,
0
,
result_len
}});
return
{
k_total
,
v_total
};
}
// ==========================
// StaticKVCacheConfig
// ==========================
StaticKVCacheConfig
::
StaticKVCacheConfig
(
infinicore
::
Size
_max_batch_size
,
infinicore
::
Size
_max_cache_len
)
:
max_batch_size_
(
_max_batch_size
),
max_cache_len_
(
_max_cache_len
)
{
}
std
::
unique_ptr
<
CacheConfig
>
StaticKVCacheConfig
::
unique_copy
()
const
{
return
std
::
make_unique
<
StaticKVCacheConfig
>
(
*
this
);
}
infinicore
::
Size
StaticKVCacheConfig
::
max_batch_size
()
const
{
return
max_batch_size_
;
}
infinicore
::
Size
StaticKVCacheConfig
::
max_cache_len
()
const
{
return
max_cache_len_
;
}
}
// namespace infinilm::cache
csrc/cache/kv_cache.hpp
View file @
ff00b5c8
#pragma once
#pragma once
#include "base_cache.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/tensor.hpp"
#include "cache_config.hpp"
#include <algorithm>
#include <algorithm>
#include <limits>
#include <memory>
#include <memory>
#include <numeric>
#include <numeric>
#include <stdexcept>
#include <stdexcept>
...
@@ -15,355 +16,70 @@
...
@@ -15,355 +16,70 @@
#include <spdlog/spdlog.h>
#include <spdlog/spdlog.h>
namespace
infinilm
::
cache
{
namespace
infinilm
::
cache
{
class
StaticKVCacheConfig
final
:
public
CacheConfig
{
public:
StaticKVCacheConfig
(
infinicore
::
Size
_max_batch_size
=
1
,
infinicore
::
Size
_max_cache_len
=
std
::
numeric_limits
<
infinicore
::
Size
>::
max
());
/**
std
::
unique_ptr
<
CacheConfig
>
unique_copy
()
const
override
;
* @brief Single layer's KV cache for incremental decoding
infinicore
::
Size
max_batch_size
()
const
;
*
infinicore
::
Size
max_cache_len
()
const
;
* Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim]
* Similar to DynamicLayer in Python cache_utils.py
*
* This represents a single layer's cache within a model-level cache container.
*/
struct
KVCacheLayer
{
infinicore
::
Tensor
k_cache
;
// [batch_size, n_kv_head, capacity, head_dim]
infinicore
::
Tensor
v_cache
;
// [batch_size, n_kv_head, capacity, head_dim]
std
::
vector
<
size_t
>
cache_positions
;
// Current position in cache
size_t
max_capacity
;
// Maximum capacity of cache
size_t
initial_capacity
;
// Initial capacity from config
size_t
initial_batch_size
;
// Initial batch size from config
float
growth_factor
;
// Growth factor for dynamic resizing
bool
initialized
;
// Whether cache has been initialized
KVCacheLayer
()
:
max_capacity
(
0
),
initial_capacity
(
4096
),
initial_batch_size
(
1
),
growth_factor
(
2.0
f
),
initialized
(
false
)
{}
/**
* @brief Initialize or update cache capacity with config parameters
* @param batch_size Current batch size
* @param num_kv_heads Number of key-value heads
* @param head_dim Head dimension
* @param seq_len Sequence length of new tokens
* @param dtype Data type
* @param device Device
* @param cache_config Cache configuration parameters
*/
void
ensure_capacity
(
size_t
batch_size
,
size_t
num_kv_heads
,
size_t
head_dim
,
size_t
seq_len
,
infinicore
::
DataType
dtype
,
const
infinicore
::
Device
&
device
,
const
CacheConfig
&
cache_config
)
{
size_t
required_capacity
=
seq_len
+
std
::
accumulate
(
cache_positions
.
begin
(),
cache_positions
.
end
(),
0
,
[](
int
a
,
int
b
)
{
return
std
::
max
(
a
,
b
);
});
// VALIDATION: Verify input parameters
if
(
num_kv_heads
==
0
||
head_dim
==
0
||
seq_len
==
0
)
{
SPDLOG_ERROR
(
"KVCacheLayer::ensure_capacity: Invalid parameters - num_kv_heads: {}, head_dim: {}, seq_len: {}"
,
num_kv_heads
,
head_dim
,
seq_len
);
throw
std
::
runtime_error
(
"KV cache ensure_capacity: invalid parameters"
);
}
// Store config parameters on first initialization
if
(
!
initialized
)
{
initial_capacity
=
cache_config
.
initial_capacity
;
initial_batch_size
=
cache_config
.
initial_batch_size
;
growth_factor
=
cache_config
.
growth_factor
;
}
// Lazy initialization
if
(
!
initialized
)
{
// Use max of required capacity and initial capacity from config
max_capacity
=
std
::
max
(
required_capacity
,
initial_capacity
);
// Use max of current batch size and initial batch size from config
size_t
alloc_batch_size
=
std
::
max
(
batch_size
,
initial_batch_size
);
k_cache
=
infinicore
::
Tensor
::
empty
({
alloc_batch_size
,
num_kv_heads
,
max_capacity
,
head_dim
},
dtype
,
device
);
v_cache
=
infinicore
::
Tensor
::
empty
({
alloc_batch_size
,
num_kv_heads
,
max_capacity
,
head_dim
},
dtype
,
device
);
cache_positions
=
std
::
vector
<
size_t
>
(
alloc_batch_size
,
0
);
initialized
=
true
;
spdlog
::
debug
(
"Initialized KV cache with batch_size={}, capacity={} (config: initial_batch={}, initial_capacity={})"
,
alloc_batch_size
,
max_capacity
,
initial_batch_size
,
initial_capacity
);
// VALIDATION: Verify cache was created correctly
if
(
k_cache
->
shape
()[
0
]
!=
alloc_batch_size
||
k_cache
->
shape
()[
1
]
!=
num_kv_heads
||
k_cache
->
shape
()[
2
]
!=
max_capacity
||
k_cache
->
shape
()[
3
]
!=
head_dim
)
{
SPDLOG_ERROR
(
"KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization"
);
throw
std
::
runtime_error
(
"KV cache initialization: shape mismatch"
);
}
}
// Grow cache if needed using growth factor from config
else
if
(
required_capacity
>
max_capacity
)
{
if
(
!
cache_config
.
allow_expand
)
{
SPDLOG_ERROR
(
"KVCacheLayer::ensure_capacity: Cache expansion not allowed by config"
);
throw
std
::
runtime_error
(
"KV cache expansion not allowed"
);
}
// Calculate new capacity using growth factor
size_t
new_capacity
=
static_cast
<
size_t
>
(
std
::
max
(
static_cast
<
float
>
(
max_capacity
)
*
growth_factor
,
static_cast
<
float
>
(
required_capacity
+
max_capacity
)));
// Ensure we don't exceed max_position_embeddings if specified
if
(
cache_config
.
max_kv_cache_length
!=
0
)
{
new_capacity
=
std
::
min
(
new_capacity
,
cache_config
.
max_kv_cache_length
);
}
// Ensure we grow by at least some minimum amount
size_t
min_growth
=
256
;
if
(
new_capacity
-
max_capacity
<
min_growth
)
{
new_capacity
=
max_capacity
+
min_growth
;
}
size_t
new_batch_size
=
std
::
max
(
batch_size
,
k_cache
->
shape
()[
0
]);
if
(
num_kv_heads
!=
k_cache
->
shape
()[
1
]
||
head_dim
!=
k_cache
->
shape
()[
3
])
{
throw
std
::
runtime_error
(
"KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache."
);
}
if
(
new_batch_size
>
cache_positions
.
size
())
{
cache_positions
.
resize
(
new_batch_size
,
0
);
}
auto
k_new
=
infinicore
::
Tensor
::
empty
({
new_batch_size
,
num_kv_heads
,
new_capacity
,
head_dim
},
dtype
,
device
);
auto
v_new
=
infinicore
::
Tensor
::
empty
({
new_batch_size
,
num_kv_heads
,
new_capacity
,
head_dim
},
dtype
,
device
);
spdlog
::
debug
(
"Growing KV cache from capacity {} to {} (growth_factor={})"
,
max_capacity
,
new_capacity
,
growth_factor
);
// Copy existing cache data
for
(
size_t
b
=
0
;
b
<
new_batch_size
;
++
b
)
{
size_t
cache_position
=
cache_positions
[
b
];
if
(
cache_position
>
0
)
{
auto
k_slice
=
k_cache
->
narrow
({{
0
,
b
,
1
},
{
2
,
0
,
cache_position
}});
auto
v_slice
=
v_cache
->
narrow
({{
0
,
b
,
1
},
{
2
,
0
,
cache_position
}});
k_new
->
narrow
({{
0
,
b
,
1
},
{
2
,
0
,
cache_position
}})
->
copy_from
(
k_slice
);
v_new
->
narrow
({{
0
,
b
,
1
},
{
2
,
0
,
cache_position
}})
->
copy_from
(
v_slice
);
}
}
k_cache
=
k_new
;
v_cache
=
v_new
;
max_capacity
=
new_capacity
;
// VALIDATION: Verify cache was grown correctly
if
(
k_cache
->
shape
()[
2
]
!=
new_capacity
)
{
SPDLOG_ERROR
(
"KVCacheLayer::ensure_capacity: New cache capacity mismatch"
);
throw
std
::
runtime_error
(
"KV cache growth: capacity mismatch"
);
}
}
// VALIDATION: Final check that capacity is sufficient
if
(
required_capacity
>
max_capacity
)
{
SPDLOG_ERROR
(
"KVCacheLayer::ensure_capacity: Capacity still insufficient after growth"
);
throw
std
::
runtime_error
(
"KV cache ensure_capacity: capacity insufficient"
);
}
}
/**
* @brief Update cache with new key and value states
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* @param cache_config Cache configuration for capacity management
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
*/
std
::
pair
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
update
(
const
infinicore
::
Tensor
&
k_new
,
const
infinicore
::
Tensor
&
v_new
,
const
CacheConfig
&
cache_config
)
{
if
(
k_new
->
ndim
()
!=
4
||
v_new
->
ndim
()
!=
4
)
{
throw
std
::
runtime_error
(
"KVCache update: k_new and v_new must be 4D tensors"
);
}
size_t
batch_size
=
k_new
->
shape
()[
0
];
size_t
num_kv_heads
=
k_new
->
shape
()[
1
];
size_t
seq_len
=
k_new
->
shape
()[
2
];
size_t
head_dim
=
k_new
->
shape
()[
3
];
// Ensure capacity with cache config
ensure_capacity
(
batch_size
,
num_kv_heads
,
head_dim
,
seq_len
,
k_new
->
dtype
(),
k_new
->
device
(),
cache_config
);
// Copy new k/v into cache at current position
bool
all_equal
=
cache_positions
.
empty
()
||
std
::
equal
(
cache_positions
.
begin
()
+
1
,
cache_positions
.
end
(),
cache_positions
.
begin
());
if
(
all_equal
)
{
auto
cache_position
=
cache_positions
[
0
];
auto
k_dst
=
k_cache
->
narrow
({{
2
,
cache_position
,
seq_len
}});
auto
v_dst
=
v_cache
->
narrow
({{
2
,
cache_position
,
seq_len
}});
k_dst
->
copy_from
(
k_new
);
v_dst
->
copy_from
(
v_new
);
// Update position
cache_position
+=
seq_len
;
for
(
size_t
b
=
0
;
b
<
batch_size
;
++
b
)
{
cache_positions
[
b
]
=
cache_position
;
}
// Return the total cache up to current position
auto
k_total
=
k_cache
->
narrow
({{
2
,
0
,
cache_position
}});
auto
v_total
=
v_cache
->
narrow
({{
2
,
0
,
cache_position
}});
return
std
::
make_pair
(
k_total
,
v_total
);
private:
}
else
{
infinicore
::
Size
max_batch_size_
;
throw
std
::
runtime_error
(
"KVCache update: cache positions must be equal among a batch."
);
infinicore
::
Size
max_cache_len_
;
}
}
};
};
/**
class
StaticKVCache
final
:
public
Cache
{
* @brief Model-level KV cache container (similar to DynamicCache in Python)
*
* Stores a list of KVCacheLayer objects, one per model layer.
* This aligns with Python backend's DynamicCache architecture.
*/
class
DynamicCache
{
public:
public:
/**
StaticKVCache
(
* @brief Construct DynamicCache with cache configuration
* @param cache_config Cache configuration parameters
infinicore
::
Size
k_dim
,
*/
infinicore
::
Size
v_dim
,
DynamicCache
(
const
CacheConfig
&
cache_config
)
infinicore
::
Size
num_k_heads
,
:
cache_config_
(
cache_config
),
layers_
(
cache_config
.
num_layers
)
{
infinicore
::
Size
num_v_heads
,
if
(
cache_config
.
num_layers
==
-
1
)
{
infinicore
::
Size
num_layers
,
throw
std
::
runtime_error
(
"DynamicCache: num_layers must be specified in CacheConfig"
);
infinicore
::
Size
max_positional_embedding
,
}
infinicore
::
DataType
dtype
,
}
const
StaticKVCacheConfig
&
config
,
const
engine
::
distributed
::
RankInfo
&
rank_info
);
/**
/**
* @brief
Construct DynamicCache with specified number of layers
* @brief
Update KV cache at a given layer and cache position.
*
*
* @param num_layers Number of model layers (creates one cache layer per model layer)
* @param layer_idx Which transformer layer
* @param max_position_embeddings Maximum position embeddings (used for initial capacity)
* @param k [batch, num_rank_k_heads, seq_len, k_dim]
*/
* @param v [batch, num_rank_v_heads, seq_len, v_dim]
DynamicCache
(
size_t
num_layers
,
size_t
max_position_embeddings
=
4096
)
* @param cache_pos Sequence position to write
:
cache_config_
(
CacheConfig
(
CacheType
::
DYNAMIC
,
num_layers
,
max_position_embeddings
)),
layers_
(
num_layers
)
{}
/**
* @brief Update cache with new key and value states for a specific layer
*/
std
::
pair
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
update
(
size_t
layer_idx
,
const
infinicore
::
Tensor
&
k_new
,
const
infinicore
::
Tensor
&
v_new
)
{
if
(
layer_idx
>=
layers_
.
size
())
{
SPDLOG_ERROR
(
"DynamicCache::update: layer_idx {} out of range (num_layers: {})"
,
layer_idx
,
layers_
.
size
());
throw
std
::
runtime_error
(
"DynamicCache: layer_idx out of range"
);
}
// Update the cache for this layer with cache config
return
layers_
[
layer_idx
].
update
(
k_new
,
v_new
,
cache_config_
);
}
/**
* @brief Update cache with new key and value states (convenience method without layer_idx)
* This is used when the cache is accessed directly without layer information
*
*
* @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
* @return (full_k, full_v)
* @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
* full_k: [batch, num_rank_k_heads, cache_pos + seq_len, k_dim]
* @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
* full_v: [batch, num_rank_v_heads, cache_pos + seq_len, v_dim]
*
* Note: This assumes layer_idx=0. For multi-layer models, use update(layer_idx, k_new, v_new) instead.
*/
std
::
pair
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
update
(
const
infinicore
::
Tensor
&
k_new
,
const
infinicore
::
Tensor
&
v_new
)
{
return
update
(
0
,
k_new
,
v_new
);
}
/**
* @brief Get cache configuration
*/
const
CacheConfig
&
get_config
()
const
{
return
cache_config_
;
}
/**
* @brief Update cache configuration (for dynamic reconfiguration)
*/
void
update_config
(
const
CacheConfig
&
new_config
)
{
// Check if we need to rebuild
bool
need_rebuild
=
false
;
// Rebuild if number of layers changed
if
(
new_config
.
num_layers
!=
cache_config_
.
num_layers
||
new_config
.
initial_batch_size
!=
cache_config_
.
initial_batch_size
)
{
need_rebuild
=
true
;
layers_
.
resize
(
new_config
.
num_layers
);
}
// Rebuild if reset mode is RECREATE
if
(
new_config
.
reset_mode
==
CacheResetMode
::
RECREATE
)
{
need_rebuild
=
true
;
}
// Update configuration
cache_config_
=
new_config
;
if
(
need_rebuild
)
{
// Clear all layers to force reinitialization on next use
for
(
auto
&
layer
:
layers_
)
{
layer
.
initialized
=
false
;
layer
.
max_capacity
=
0
;
// Tensors will be recreated when ensure_capacity is called
}
spdlog
::
info
(
"DynamicCache configuration updated - cache will be rebuilt on next use"
);
}
else
{
spdlog
::
info
(
"DynamicCache configuration updated: layers={}, initial_capacity={}, growth_factor={}"
,
new_config
.
num_layers
,
new_config
.
initial_capacity
,
new_config
.
growth_factor
);
}
}
/**
* @brief Get the number of layers in this cache
*/
size_t
num_layers
()
const
{
return
layers_
.
size
();
}
/**
* @brief Get cache position for a specific layer
*/
size_t
cache_position
(
size_t
layer_idx
)
const
{
if
(
layer_idx
>=
layers_
.
size
())
{
throw
std
::
runtime_error
(
"DynamicCache: layer_idx out of range"
);
}
if
(
layers_
[
layer_idx
].
cache_positions
.
empty
())
{
return
0
;
}
return
layers_
[
layer_idx
].
cache_positions
[
0
];
// All batch items should have same position
}
/**
* @brief Get max position embeddings (used for initial capacity)
*/
size_t
max_kv_cache_length
()
const
{
return
cache_config_
.
max_kv_cache_length
;
}
/**
* @brief Reset cache for all layers to a specific position
* This should be called when starting a new generation sequence or resetting to a specific position
* @param pos Position to reset to (defaults to 0)
*/
void
reset
(
size_t
pos
=
0
)
{
for
(
auto
&
layer
:
layers_
)
{
std
::
fill
(
layer
.
cache_positions
.
begin
(),
layer
.
cache_positions
.
end
(),
pos
);
// Note: We don't reset initialized flag or clear the cache tensors
// to avoid reallocation. The cache will be overwritten on next update.
}
}
/**
* @brief Access a specific layer's cache (for advanced usage)
*/
*/
KVCacheLayer
&
layer
(
size_t
layer_idx
)
{
std
::
tuple
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
if
(
layer_idx
>=
layers_
.
size
())
{
update
(
size_t
layer_idx
,
throw
std
::
runtime_error
(
"DynamicCache: layer_idx out of range"
);
const
infinicore
::
Tensor
&
k
,
}
const
infinicore
::
Tensor
&
v
,
return
layers_
[
layer_idx
];
const
infinicore
::
Tensor
&
cache_positions
);
}
const
KVCacheLayer
&
layer
(
size_t
layer_idx
)
const
{
~
StaticKVCache
()
override
=
default
;
if
(
layer_idx
>=
layers_
.
size
())
{
throw
std
::
runtime_error
(
"DynamicCache: layer_idx out of range"
);
}
return
layers_
[
layer_idx
];
}
private:
private:
CacheConfig
cache_config_
;
infinicore
::
Size
k_dim_
;
std
::
vector
<
KVCacheLayer
>
layers_
;
infinicore
::
Size
v_dim_
;
infinicore
::
Size
num_rank_k_heads_
;
infinicore
::
Size
num_rank_v_heads_
;
infinicore
::
Size
rank_batch_size_
;
infinicore
::
Size
cache_len_
;
infinicore
::
Size
rank_num_layers_
;
infinicore
::
DataType
dtype_
;
// [num_layers, max_batch, num_rank_k_heads, max_cache_len, k_dim]
infinicore
::
Tensor
k_caches_
;
// [num_layers, max_batch, num_rank_v_heads, max_cache_len, v_dim]
infinicore
::
Tensor
v_caches_
;
};
};
}
// namespace infinilm::cache
}
// namespace infinilm::cache
csrc/engine/distributed/communication_group.cpp
View file @
ff00b5c8
...
@@ -38,7 +38,7 @@ int CommunicationGroup::get_world_size() const {
...
@@ -38,7 +38,7 @@ int CommunicationGroup::get_world_size() const {
CommunicationGroup
::~
CommunicationGroup
()
{
CommunicationGroup
::~
CommunicationGroup
()
{
if
(
communicators_
.
size
()
>
1
)
{
if
(
communicators_
.
size
()
>
1
)
{
for
(
auto
&
comm
:
communicators_
)
{
for
(
auto
&
comm
:
communicators_
)
{
RUN_INFINI
(
infinicclCommDestroy
(
comm
)
)
;
infinicclCommDestroy
(
comm
);
}
}
}
}
}
}
...
...
csrc/engine/infer_engine.cpp
View file @
ff00b5c8
...
@@ -10,32 +10,13 @@ InferEngine::InferEngine(
...
@@ -10,32 +10,13 @@ InferEngine::InferEngine(
const
InfinilmModel
::
Config
&
config
,
const
InfinilmModel
::
Config
&
config
,
const
distributed
::
DistConfig
&
distributed_config
,
const
distributed
::
DistConfig
&
distributed_config
,
infinicore
::
Device
::
Type
device_type
,
infinicore
::
Device
::
Type
device_type
,
const
cache
::
CacheConfig
&
cache_config
)
// Changed parameter
const
cache
::
CacheConfig
*
cache_config
)
// Changed parameter
:
communication_group_
(
distributed_config
,
device_type
),
:
communication_group_
(
distributed_config
,
device_type
),
model_config_
(
config
),
model_config_
(
config
)
{
cache_config_
(
cache_config
)
{
spdlog
::
info
(
"Launch InferEngine with {}"
,
std
::
string
(
distributed_config
));
if
(
cache_config
!=
nullptr
)
{
spdlog
::
info
(
"Cache configuration: type={}, layers={}, max_kv_cache_length={}"
,
cache_config_
=
cache_config
->
unique_copy
();
static_cast
<
int
>
(
cache_config_
.
type
),
cache_config_
.
num_layers
,
cache_config_
.
max_kv_cache_length
);
// Try to extract model configuration to override default cache parameters if needed
try
{
if
(
const
auto
llama_config_ptr
=
dynamic_cast
<
const
models
::
llama
::
LlamaConfig
*>
(
&
config
))
{
const
auto
&
llama_config
=
*
llama_config_ptr
;
cache_config_
.
num_layers
=
llama_config
.
num_hidden_layers
;
cache_config_
.
max_kv_cache_length
=
llama_config
.
max_position_embeddings
;
spdlog
::
info
(
"Updated cache config from model: layers={}, max_kv_cache_length={}"
,
cache_config_
.
num_layers
,
cache_config_
.
max_kv_cache_length
);
}
}
catch
(...)
{
spdlog
::
warn
(
"Could not extract model config, using provided CacheConfig"
);
}
}
// Create one RankWorker per rank
// Create one RankWorker per rank
int
world_size
=
communication_group_
.
get_world_size
();
int
world_size
=
communication_group_
.
get_world_size
();
workers_
.
reserve
(
world_size
);
workers_
.
reserve
(
world_size
);
...
@@ -43,7 +24,7 @@ InferEngine::InferEngine(
...
@@ -43,7 +24,7 @@ InferEngine::InferEngine(
workers_
.
emplace_back
(
std
::
make_unique
<
RankWorker
>
(
workers_
.
emplace_back
(
std
::
make_unique
<
RankWorker
>
(
model_config_
,
model_config_
,
communication_group_
.
get_rank_info
(
r
),
communication_group_
.
get_rank_info
(
r
),
cache_config_
));
cache_config_
!=
nullptr
?
cache_config_
.
get
()
:
nullptr
));
}
}
}
}
...
@@ -75,12 +56,14 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
...
@@ -75,12 +56,14 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
//------------------------------------------------------
//------------------------------------------------------
// forward
// forward
//------------------------------------------------------
//------------------------------------------------------
InferEngine
::
Output
InferEngine
::
forward
(
const
InferEngine
::
Input
&
input
)
{
infinilm
::
InfinilmModel
::
Input
InferEngine
::
Input
::
to_model_input
()
const
{
const
auto
&
[
input_ids
,
position_ids
]
=
input
;
return
{
input_ids
,
position_ids
,
cache_positions
};
}
InferEngine
::
Output
InferEngine
::
forward
(
const
InferEngine
::
Input
&
input
)
{
// Trigger each worker to run inference
// Trigger each worker to run inference
for
(
auto
&
worker
:
workers_
)
{
for
(
auto
&
worker
:
workers_
)
{
worker
->
run
(
{
input
_ids
,
position_ids
}
);
worker
->
run
(
input
.
to_model_input
()
);
}
}
// Wait for all workers
// Wait for all workers
for
(
auto
&
worker
:
workers_
)
{
for
(
auto
&
worker
:
workers_
)
{
...
@@ -104,25 +87,12 @@ const distributed::DistConfig &InferEngine::get_dist_config() const {
...
@@ -104,25 +87,12 @@ const distributed::DistConfig &InferEngine::get_dist_config() const {
return
communication_group_
.
get_dist_config
();
return
communication_group_
.
get_dist_config
();
}
}
//------------------------------------------------------
// reset_cache
//------------------------------------------------------
void
InferEngine
::
reset_cache
(
size_t
pos
)
{
for
(
auto
&
worker
:
workers_
)
{
worker
->
reset_cache
(
pos
);
}
for
(
auto
&
worker
:
workers_
)
{
worker
->
wait
();
}
}
//------------------------------------------------------
//------------------------------------------------------
// reset_cache (overloaded with CacheConfig)
// reset_cache (overloaded with CacheConfig)
//------------------------------------------------------
//------------------------------------------------------
void
InferEngine
::
reset_cache
(
const
cache
::
CacheConfig
&
new_config
,
size_t
pos
)
{
void
InferEngine
::
reset_cache
(
const
cache
::
CacheConfig
*
new_config
)
{
cache_config_
=
new_config
;
for
(
auto
&
worker
:
workers_
)
{
for
(
auto
&
worker
:
workers_
)
{
worker
->
reset_cache
(
new_config
,
pos
);
worker
->
reset_cache
(
new_config
);
}
}
for
(
auto
&
worker
:
workers_
)
{
for
(
auto
&
worker
:
workers_
)
{
worker
->
wait
();
worker
->
wait
();
...
...
csrc/engine/infer_engine.hpp
View file @
ff00b5c8
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
#include "distributed/distributed.hpp"
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"
#include "rank_worker.hpp"
#include "../models/infinilm_model.hpp"
#include <any>
#include <vector>
#include <vector>
namespace
infinilm
::
engine
{
namespace
infinilm
::
engine
{
...
@@ -16,6 +16,10 @@ public:
...
@@ -16,6 +16,10 @@ public:
infinicore
::
Tensor
input_ids
;
infinicore
::
Tensor
input_ids
;
infinicore
::
Tensor
position_ids
;
infinicore
::
Tensor
position_ids
;
infinicore
::
Tensor
cache_positions
;
infinilm
::
InfinilmModel
::
Input
to_model_input
()
const
;
};
};
struct
Output
{
struct
Output
{
...
@@ -27,7 +31,7 @@ public:
...
@@ -27,7 +31,7 @@ public:
const
InfinilmModel
::
Config
&
config
,
const
InfinilmModel
::
Config
&
config
,
const
distributed
::
DistConfig
&
distributed_config
=
distributed
::
DistConfig
(),
const
distributed
::
DistConfig
&
distributed_config
=
distributed
::
DistConfig
(),
infinicore
::
Device
::
Type
device_type
=
infinicore
::
context
::
getDevice
().
getType
(),
infinicore
::
Device
::
Type
device_type
=
infinicore
::
context
::
getDevice
().
getType
(),
const
cache
::
CacheConfig
&
cache_config
=
cache
::
CacheConfig
()
);
const
cache
::
CacheConfig
*
cache_config
=
nullptr
);
// Load a parameter to all workers (each can extract its shard inside RankWorker)
// Load a parameter to all workers (each can extract its shard inside RankWorker)
void
load_param
(
const
std
::
string
&
name
,
const
infinicore
::
Tensor
&
param
);
void
load_param
(
const
std
::
string
&
name
,
const
infinicore
::
Tensor
&
param
);
...
@@ -38,24 +42,20 @@ public:
...
@@ -38,24 +42,20 @@ public:
// Run a single forward pass on all workers and return the outputs from all ranks
// Run a single forward pass on all workers and return the outputs from all ranks
Output
forward
(
const
Input
&
input
);
Output
forward
(
const
Input
&
input
);
// Reset the internal cache pos in all workers (clears state between generations)
void
reset_cache
(
const
cache
::
CacheConfig
*
new_config
);
void
reset_cache
(
size_t
pos
=
0
);
// Overload: reset cache with new KV configuration
void
reset_cache
(
const
cache
::
CacheConfig
&
new_config
,
size_t
pos
=
0
);
~
InferEngine
();
~
InferEngine
();
const
distributed
::
DistConfig
&
get_dist_config
()
const
;
const
distributed
::
DistConfig
&
get_dist_config
()
const
;
// Get current KV configuration
// Get current KV configuration
const
cache
::
CacheConfig
&
get_cache_config
()
const
{
return
cache_config_
;
}
const
cache
::
CacheConfig
*
get_cache_config
()
const
{
return
cache_config_
.
get
()
;
}
protected:
protected:
std
::
vector
<
std
::
unique_ptr
<
RankWorker
>>
workers_
;
std
::
vector
<
std
::
unique_ptr
<
RankWorker
>>
workers_
;
distributed
::
CommunicationGroup
communication_group_
;
distributed
::
CommunicationGroup
communication_group_
;
const
InfinilmModel
::
Config
&
model_config_
;
const
InfinilmModel
::
Config
&
model_config_
;
cache
::
CacheConfig
cache_config_
;
std
::
unique_ptr
<
cache
::
CacheConfig
>
cache_config_
;
};
};
}
// namespace infinilm::engine
}
// namespace infinilm::engine
csrc/engine/rank_worker.cpp
View file @
ff00b5c8
...
@@ -10,15 +10,17 @@ namespace infinilm::engine {
...
@@ -10,15 +10,17 @@ namespace infinilm::engine {
RankWorker
::
RankWorker
(
const
InfinilmModel
::
Config
&
model_config
,
RankWorker
::
RankWorker
(
const
InfinilmModel
::
Config
&
model_config
,
const
distributed
::
RankInfo
&
rank_info
,
const
distributed
::
RankInfo
&
rank_info
,
const
cache
::
CacheConfig
&
cache_config
)
const
cache
::
CacheConfig
*
cache_config
)
:
model_config_
(
model_config
),
:
model_config_
(
model_config
),
rank_info_
(
rank_info
),
rank_info_
(
rank_info
),
job_cmd_
(
Command
::
INIT
),
job_cmd_
(
Command
::
INIT
),
has_job_
(
false
),
has_job_
(
false
),
job_done_
(
false
),
job_done_
(
false
),
should_exit_
(
false
),
should_exit_
(
false
),
init_done_
(
false
),
init_done_
(
false
)
{
pending_cache_config_
(
cache_config
)
{
if
(
cache_config
!=
nullptr
)
{
pending_cache_config_
=
cache_config
->
unique_copy
();
}
// start the thread
// start the thread
thread_
=
std
::
thread
(
&
RankWorker
::
thread_loop
,
this
);
thread_
=
std
::
thread
(
&
RankWorker
::
thread_loop
,
this
);
...
@@ -80,7 +82,14 @@ void RankWorker::load_param(const std::string &name,
...
@@ -80,7 +82,14 @@ void RankWorker::load_param(const std::string &name,
// state_dict --
// state_dict --
//------------------------------------------------------
//------------------------------------------------------
std
::
unordered_map
<
std
::
string
,
infinicore
::
nn
::
Parameter
>
RankWorker
::
state_dict
()
{
std
::
unordered_map
<
std
::
string
,
infinicore
::
nn
::
Parameter
>
RankWorker
::
state_dict
()
{
return
this
->
model_
->
state_dict
();
std
::
unique_lock
<
std
::
mutex
>
lk
(
mutex_
);
cv_
.
wait
(
lk
,
[
&
]
{
return
init_done_
||
should_exit_
;
});
if
(
!
model_
)
{
throw
std
::
runtime_error
(
"state_dict called before model initialization"
);
}
return
model_
->
state_dict
();
}
}
//------------------------------------------------------
//------------------------------------------------------
...
@@ -113,32 +122,15 @@ void RankWorker::wait() {
...
@@ -113,32 +122,15 @@ void RankWorker::wait() {
}
}
}
}
//------------------------------------------------------
void
RankWorker
::
reset_cache
(
const
cache
::
CacheConfig
*
new_config
)
{
// reset_cache -- synchronous by default, async optional (unstable)
//------------------------------------------------------
void
RankWorker
::
reset_cache
(
size_t
pos
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
should_exit_
)
{
throw
std
::
runtime_error
(
"RankWorker is closing; cannot reset_cache"
);
}
pending_reset_pos_
=
pos
;
job_cmd_
=
Command
::
RESET_CACHE
;
has_job_
=
true
;
job_done_
=
false
;
cv_
.
notify_all
();
}
void
RankWorker
::
reset_cache
(
const
cache
::
CacheConfig
&
new_config
,
size_t
pos
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
should_exit_
)
{
if
(
should_exit_
)
{
throw
std
::
runtime_error
(
"RankWorker is closing; cannot reset_cache"
);
throw
std
::
runtime_error
(
"RankWorker is closing; cannot reset_cache"
);
}
}
// Store both the position and the new config
// Store both the position and the new config
pending_reset_pos_
=
pos
;
pending_cache_config_
=
new_config
->
unique_copy
();
pending_cache_config_
=
new_config
;
job_cmd_
=
Command
::
RESET_CACHE
;
job_cmd_
=
Command
::
RESET_CACHE_WITH_CONFIG
;
has_job_
=
true
;
has_job_
=
true
;
job_done_
=
false
;
job_done_
=
false
;
cv_
.
notify_all
();
cv_
.
notify_all
();
...
@@ -174,17 +166,17 @@ InfinilmModel::Output RankWorker::get_output() {
...
@@ -174,17 +166,17 @@ InfinilmModel::Output RankWorker::get_output() {
//------------------------------------------------------
//------------------------------------------------------
void
RankWorker
::
thread_loop
()
{
void
RankWorker
::
thread_loop
()
{
try
{
try
{
// Initialize device & model outside of holding the main mutex to avoid blocking callers.
infinicore
::
context
::
setDevice
(
rank_info_
.
device
);
cache_ptr_
=
std
::
make_shared
<
cache
::
DynamicCache
>
(
pending_cache_config_
);
// Create model using factory (may be expensive)
model_
=
InfinilmModelFactory
::
createModel
(
model_config_
,
rank_info_
,
cache_ptr_
);
// Signal that initialization is done
{
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
// Initialize device & model outside of holding the main mutex to avoid blocking callers.
infinicore
::
context
::
setDevice
(
rank_info_
.
device
);
// Create model using factory (may be expensive)
model_
=
InfinilmModelFactory
::
createModel
(
model_config_
,
rank_info_
,
pending_cache_config_
!=
nullptr
?
pending_cache_config_
.
get
()
:
nullptr
);
if
(
!
model_
)
{
throw
std
::
runtime_error
(
"Failed to create model"
);
}
init_done_
=
true
;
init_done_
=
true
;
}
}
cv_
.
notify_all
();
cv_
.
notify_all
();
...
@@ -195,8 +187,7 @@ void RankWorker::thread_loop() {
...
@@ -195,8 +187,7 @@ void RankWorker::thread_loop() {
std
::
string
local_param_name
;
std
::
string
local_param_name
;
infinicore
::
Tensor
local_param
;
infinicore
::
Tensor
local_param
;
InfinilmModel
::
Input
local_args
;
InfinilmModel
::
Input
local_args
;
size_t
local_reset_pos
=
0
;
std
::
unique_ptr
<
cache
::
CacheConfig
>
local_cache_config
;
cache
::
CacheConfig
local_reset_config
;
// Wait for a job or exit
// Wait for a job or exit
{
{
...
@@ -215,12 +206,10 @@ void RankWorker::thread_loop() {
...
@@ -215,12 +206,10 @@ void RankWorker::thread_loop() {
}
else
if
(
local_cmd
==
Command
::
RUN
)
{
}
else
if
(
local_cmd
==
Command
::
RUN
)
{
local_args
=
pending_args_
;
local_args
=
pending_args_
;
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE
)
{
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE
)
{
local_reset_pos
=
pending_reset_pos_
;
if
(
pending_cache_config_
!=
nullptr
)
{
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE_WITH_CONFIG
)
{
local_cache_config
=
pending_cache_config_
->
unique_copy
();
local_reset_pos
=
pending_reset_pos_
;
}
local_reset_config
=
pending_cache_config_
;
}
}
// mark job as being processed
// mark job as being processed
has_job_
=
false
;
has_job_
=
false
;
job_done_
=
false
;
job_done_
=
false
;
...
@@ -270,14 +259,7 @@ void RankWorker::thread_loop() {
...
@@ -270,14 +259,7 @@ void RankWorker::thread_loop() {
}
}
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE
)
{
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE
)
{
try
{
try
{
// Option 1: Use model's reset_cache if it handles cache
model_
->
reset_cache
(
local_cache_config
!=
nullptr
?
local_cache_config
.
get
()
:
nullptr
);
model_
->
reset_cache
(
local_reset_pos
);
// Option 2: Reset cache directly if we have access
// if (cache_ptr_ != nullptr) {
// auto* dynamic_cache = static_cast<cache::DynamicCache*>(cache_ptr_);
// dynamic_cache->reset(local_reset_pos);
// }
{
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
...
@@ -293,25 +275,6 @@ void RankWorker::thread_loop() {
...
@@ -293,25 +275,6 @@ void RankWorker::thread_loop() {
spdlog
::
error
(
"[{}] exception during reset_cache: {}
\n
"
,
info
(),
e
.
what
());
spdlog
::
error
(
"[{}] exception during reset_cache: {}
\n
"
,
info
(),
e
.
what
());
break
;
break
;
}
}
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE_WITH_CONFIG
)
{
try
{
// Use model's reset_cache with new configuration
model_
->
reset_cache
(
local_reset_config
,
local_reset_pos
);
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
job_done_
=
true
;
}
cv_
.
notify_all
();
}
catch
(
const
std
::
exception
&
e
)
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
should_exit_
=
true
;
job_done_
=
true
;
cv_
.
notify_all
();
spdlog
::
error
(
"[{}] exception during reset_cache with config: {}
\n
"
,
info
(),
e
.
what
());
break
;
}
}
else
{
}
else
{
// Shouldn't reach here (no-op)
// Shouldn't reach here (no-op)
}
}
...
...
csrc/engine/rank_worker.hpp
View file @
ff00b5c8
...
@@ -19,14 +19,13 @@ class RankWorker {
...
@@ -19,14 +19,13 @@ class RankWorker {
LOAD
,
LOAD
,
RUN
,
RUN
,
RESET_CACHE
,
RESET_CACHE
,
RESET_CACHE_WITH_CONFIG
,
STOP
STOP
};
};
public:
public:
RankWorker
(
const
InfinilmModel
::
Config
&
model_config
,
RankWorker
(
const
InfinilmModel
::
Config
&
model_config
,
const
distributed
::
RankInfo
&
rank_info
,
const
distributed
::
RankInfo
&
rank_info
,
const
cache
::
CacheConfig
&
cache_config
);
const
cache
::
CacheConfig
*
cache_config
);
// Submit a parameter load job and wait until the load completes on the worker thread.
// Submit a parameter load job and wait until the load completes on the worker thread.
void
load_param
(
const
std
::
string
&
name
,
void
load_param
(
const
std
::
string
&
name
,
...
@@ -38,11 +37,8 @@ public:
...
@@ -38,11 +37,8 @@ public:
// Submit a run (forward) job.
// Submit a run (forward) job.
void
run
(
const
InfinilmModel
::
Input
&
args
);
void
run
(
const
InfinilmModel
::
Input
&
args
);
// Reset the internal cache in the model (clears state between generations)
void
reset_cache
(
size_t
pos
=
0
);
// Reset the internal cache with a new configuration
// Reset the internal cache with a new configuration
void
reset_cache
(
const
cache
::
CacheConfig
&
new_config
,
size_t
pos
=
0
);
void
reset_cache
(
const
cache
::
CacheConfig
*
new_config
);
// Wait until run job completes. The result can be retrieved with get_output().
// Wait until run job completes. The result can be retrieved with get_output().
void
wait
();
void
wait
();
...
@@ -63,7 +59,7 @@ private:
...
@@ -63,7 +59,7 @@ private:
const
InfinilmModel
::
Config
&
model_config_
;
const
InfinilmModel
::
Config
&
model_config_
;
distributed
::
RankInfo
rank_info_
;
distributed
::
RankInfo
rank_info_
;
std
::
shared_ptr
<
InfinilmModel
>
model_
;
std
::
shared_ptr
<
InfinilmModel
>
model_
;
std
::
shared_ptr
<
cache
::
Dynamic
Cache
>
cache_
ptr_
;
std
::
shared_ptr
<
cache
::
Cache
>
cache_
;
// Command for the pending job (protected by mutex_)
// Command for the pending job (protected by mutex_)
Command
job_cmd_
;
Command
job_cmd_
;
...
@@ -78,8 +74,7 @@ private:
...
@@ -78,8 +74,7 @@ private:
std
::
string
pending_param_name_
;
std
::
string
pending_param_name_
;
infinicore
::
Tensor
pending_param_
;
infinicore
::
Tensor
pending_param_
;
InfinilmModel
::
Input
pending_args_
;
InfinilmModel
::
Input
pending_args_
;
size_t
pending_reset_pos_
=
0
;
std
::
unique_ptr
<
cache
::
CacheConfig
>
pending_cache_config_
;
cache
::
CacheConfig
pending_cache_config_
;
// Output (protected by mutex)
// Output (protected by mutex)
InfinilmModel
::
Output
output_
;
InfinilmModel
::
Output
output_
;
...
...
csrc/models/infinilm_model.hpp
View file @
ff00b5c8
...
@@ -18,12 +18,10 @@ public:
...
@@ -18,12 +18,10 @@ public:
struct
Input
{
struct
Input
{
/// Token IDs tensor of shape `[batch, seq_len]`.
/// Token IDs tensor of shape `[batch, seq_len]`.
infinicore
::
Tensor
input_ids
;
infinicore
::
Tensor
input_ids
;
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
infinicore
::
Tensor
position_ids
;
infinicore
::
Tensor
position_ids
;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
/// Optional model-level KV cache for incremental decoding. Defaults to `nullptr`.
infinicore
::
Tensor
cache_positions
;
void
*
kv_cache
=
nullptr
;
};
};
struct
Output
{
struct
Output
{
...
@@ -33,8 +31,7 @@ public:
...
@@ -33,8 +31,7 @@ public:
virtual
~
InfinilmModel
()
=
default
;
virtual
~
InfinilmModel
()
=
default
;
virtual
Output
forward
(
const
Input
&
input
)
const
=
0
;
virtual
Output
forward
(
const
Input
&
input
)
const
=
0
;
// Optional: reset cache; default no-op for models without cache
virtual
void
reset_cache
(
size_t
pos
=
0
)
{}
virtual
void
reset_cache
(
const
cache
::
CacheConfig
*
cache_config
)
=
0
;
virtual
void
reset_cache
(
const
cache
::
CacheConfig
&
new_config
,
size_t
pos
=
0
)
=
0
;
};
};
}
// namespace infinilm
}
// namespace infinilm
csrc/models/llama/llama_attention.cpp
View file @
ff00b5c8
...
@@ -51,7 +51,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
...
@@ -51,7 +51,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
infinicore
::
Tensor
LlamaAttention
::
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
infinicore
::
Tensor
LlamaAttention
::
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
const
infinicore
::
Tensor
&
position_ids
,
void
*
kv_cache
)
const
{
std
::
shared_ptr
<
cache
::
Cache
>
kv_cache
,
const
infinicore
::
Tensor
&
cache_positions
)
const
{
if
(
!
rotary_emb_
)
{
if
(
!
rotary_emb_
)
{
throw
std
::
runtime_error
(
"LlamaAttention: rotary_emb not configured"
);
throw
std
::
runtime_error
(
"LlamaAttention: rotary_emb not configured"
);
}
}
...
@@ -97,16 +98,15 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
...
@@ -97,16 +98,15 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
q_reshaped
=
q_rope
->
permute
({
0
,
2
,
1
,
3
});
// [bs, n_q_head, seq_len, head_dim]
q_reshaped
=
q_rope
->
permute
({
0
,
2
,
1
,
3
});
// [bs, n_q_head, seq_len, head_dim]
auto
k_permuted
=
k_reshaped
->
permute
({
0
,
2
,
1
,
3
});
// [bs, n_kv_head, seq_len, head_dim]
auto
k_permuted
=
k_reshaped
->
permute
({
0
,
2
,
1
,
3
});
// [bs, n_kv_head, seq_len, head_dim]
auto
v_permuted
=
v_reshaped
->
permute
({
0
,
2
,
1
,
3
});
// [bs, n_kv_head, seq_len, head_dim]
auto
v_permuted
=
v_reshaped
->
permute
({
0
,
2
,
1
,
3
});
// [bs, n_kv_head, seq_len, head_dim]
infinilm
::
cache
::
DynamicCache
*
external_cache
=
static_cast
<
infinilm
::
cache
::
DynamicCache
*>
(
kv_cache
);
infinicore
::
Tensor
k_total
;
// [bs, n_kv_head, total_seq_len, head_dim]
infinicore
::
Tensor
k_total
;
// [bs, n_kv_head, total_seq_len, head_dim]
infinicore
::
Tensor
v_total
;
// [bs, n_kv_head, total_seq_len, head_dim]
infinicore
::
Tensor
v_total
;
// [bs, n_kv_head, total_seq_len, head_dim]
if
(
auto
static_kv_cache
=
std
::
dynamic_pointer_cast
<
cache
::
StaticKVCache
>
(
kv_cache
))
{
if
(
external_cache
!=
nullptr
)
{
auto
[
k_total_tmp
,
v_total_tmp
]
=
static_kv_cache
->
update
(
layer_idx_
,
k_permuted
,
v_permuted
,
cache_positions
);
auto
[
k_total_tmp
,
v_total_tmp
]
=
external_cache
->
update
(
layer_idx_
,
k_permuted
,
v_permuted
);
k_total
=
k_total_tmp
;
k_total
=
k_total_tmp
;
v_total
=
v_total_tmp
;
v_total
=
v_total_tmp
;
}
else
{
}
else
{
// No external cache - this shouldn't happen in normal operation, but handle gracefully
throw
std
::
runtime_error
(
"LlamaAttention: kv
_
cache
is required but nullptr provided
"
);
throw
std
::
runtime_error
(
"LlamaAttention:
Unsupported
kvcache
type
"
);
}
}
auto
total_seq_len
=
k_total
->
shape
()[
2
];
auto
total_seq_len
=
k_total
->
shape
()[
2
];
...
...
csrc/models/llama/llama_attention.hpp
View file @
ff00b5c8
...
@@ -50,7 +50,8 @@ public:
...
@@ -50,7 +50,8 @@ public:
*/
*/
infinicore
::
Tensor
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
infinicore
::
Tensor
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
const
infinicore
::
Tensor
&
position_ids
,
void
*
kv_cache
=
nullptr
)
const
;
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
const
infinicore
::
Tensor
&
cache_positions
)
const
;
/**
/**
* @brief Get the layer index
* @brief Get the layer index
...
...
csrc/models/llama/llama_decoder_layer.cpp
View file @
ff00b5c8
...
@@ -23,7 +23,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
...
@@ -23,7 +23,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
infinicore
::
Tensor
LlamaDecoderLayer
::
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
infinicore
::
Tensor
LlamaDecoderLayer
::
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
const
infinicore
::
Tensor
&
position_ids
,
void
*
kv_cache
)
const
{
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
const
infinicore
::
Tensor
&
cache_positions
)
const
{
// Save residual for attention
// Save residual for attention
auto
residual
=
hidden_states
;
auto
residual
=
hidden_states
;
...
@@ -31,7 +32,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
...
@@ -31,7 +32,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
auto
normed_states
=
input_layernorm_
->
forward
(
hidden_states
);
auto
normed_states
=
input_layernorm_
->
forward
(
hidden_states
);
// 2. Self-attention with residual connection
// 2. Self-attention with residual connection
auto
attn_output
=
self_attn_
->
forward
(
normed_states
,
position_ids
,
kv_cache
);
auto
attn_output
=
self_attn_
->
forward
(
normed_states
,
position_ids
,
kv_cache
,
cache_positions
);
// Add residual: hidden_states = hidden_states + attn_output
// Add residual: hidden_states = hidden_states + attn_output
auto
output
=
infinicore
::
op
::
add
(
residual
,
attn_output
);
auto
output
=
infinicore
::
op
::
add
(
residual
,
attn_output
);
...
...
csrc/models/llama/llama_decoder_layer.hpp
View file @
ff00b5c8
...
@@ -48,7 +48,8 @@ public:
...
@@ -48,7 +48,8 @@ public:
*/
*/
infinicore
::
Tensor
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
infinicore
::
Tensor
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
const
infinicore
::
Tensor
&
position_ids
,
void
*
kv_cache
=
nullptr
)
const
;
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
const
infinicore
::
Tensor
&
cache_positions
)
const
;
/**
/**
* @brief Get the layer index
* @brief Get the layer index
...
...
csrc/models/llama/llama_for_causal_lm.cpp
View file @
ff00b5c8
...
@@ -26,11 +26,11 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
...
@@ -26,11 +26,11 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
}
}
LlamaForCausalLM
::
Output
LlamaForCausalLM
::
forward
(
const
Input
&
input
)
const
{
LlamaForCausalLM
::
Output
LlamaForCausalLM
::
forward
(
const
Input
&
input
)
const
{
const
auto
&
[
input_ids
,
position_ids
,
kv_
cache
]
=
input
;
const
auto
&
[
input_ids
,
position_ids
,
cache
_position
]
=
input
;
// 1. Forward through base model to get hidden states
// 1. Forward through base model to get hidden states
auto
position_ids_device
=
position_ids
->
to
(
device_
);
auto
position_ids_device
=
position_ids
->
to
(
device_
);
auto
hidden_states
=
model_
->
forward
(
input_ids
,
position_ids_device
,
kv_
cache
);
auto
hidden_states
=
model_
->
forward
(
input_ids
,
position_ids_device
,
cache
_position
);
// 2. Apply language modeling head to get logits
// 2. Apply language modeling head to get logits
auto
logits
=
lm_head_
->
forward
(
hidden_states
);
auto
logits
=
lm_head_
->
forward
(
hidden_states
);
...
@@ -38,12 +38,8 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
...
@@ -38,12 +38,8 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
return
{
logits
};
return
{
logits
};
}
}
void
LlamaForCausalLM
::
reset_cache
(
size_t
pos
)
{
void
LlamaForCausalLM
::
reset_cache
(
const
cache
::
CacheConfig
*
cache_config
)
{
model_
->
reset_cache
(
pos
);
model_
->
reset_cache
(
cache_config
);
}
void
LlamaForCausalLM
::
reset_cache
(
const
cache
::
CacheConfig
&
new_config
,
size_t
pos
)
{
model_
->
reset_cache
(
new_config
,
pos
);
}
}
}
// namespace infinilm::models::llama
}
// namespace infinilm::models::llama
csrc/models/llama/llama_for_causal_lm.hpp
View file @
ff00b5c8
...
@@ -40,9 +40,7 @@ public:
...
@@ -40,9 +40,7 @@ public:
*/
*/
Output
forward
(
const
Input
&
input
)
const
;
Output
forward
(
const
Input
&
input
)
const
;
// Reset internal cache position
void
reset_cache
(
const
cache
::
CacheConfig
*
cache_config
)
override
;
void
reset_cache
(
size_t
pos
=
0
)
override
;
void
reset_cache
(
const
cache
::
CacheConfig
&
new_config
,
size_t
pos
)
override
;
// Module information
// Module information
const
LlamaConfig
&
config
()
const
{
return
model_
->
config
();
}
const
LlamaConfig
&
config
()
const
{
return
model_
->
config
();
}
...
...
csrc/models/llama/llama_model.cpp
View file @
ff00b5c8
...
@@ -10,9 +10,8 @@ namespace infinilm::models::llama {
...
@@ -10,9 +10,8 @@ namespace infinilm::models::llama {
LlamaModel
::
LlamaModel
(
const
LlamaConfig
&
config
,
LlamaModel
::
LlamaModel
(
const
LlamaConfig
&
config
,
const
infinicore
::
Device
&
device
,
const
infinicore
::
Device
&
device
,
engine
::
distributed
::
RankInfo
rank_info
)
engine
::
distributed
::
RankInfo
rank_info
)
:
config_
(
config
)
{
:
config_
(
config
)
,
rank_info_
(
rank_info
)
{
const
auto
&
dtype
{
config
.
dtype
};
const
auto
&
dtype
{
config
.
dtype
};
// Initialize token embeddings
// Initialize token embeddings
INFINICORE_NN_MODULE_INIT
(
embed_tokens
,
config
.
vocab_size
,
config
.
hidden_size
,
INFINICORE_NN_MODULE_INIT
(
embed_tokens
,
config
.
vocab_size
,
config
.
hidden_size
,
std
::
nullopt
,
dtype
,
device
);
std
::
nullopt
,
dtype
,
device
);
...
@@ -46,72 +45,46 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
...
@@ -46,72 +45,46 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
infinicore
::
Tensor
LlamaModel
::
forward
(
const
infinicore
::
Tensor
&
input_ids
,
infinicore
::
Tensor
LlamaModel
::
forward
(
const
infinicore
::
Tensor
&
input_ids
,
const
infinicore
::
Tensor
&
position_ids
,
const
infinicore
::
Tensor
&
position_ids
,
void
*
kv_cache
)
const
{
const
infinicore
::
Tensor
&
cache_positions
)
const
{
// Use persistent internal cache if no external cache is provided
// This matches Python backend behavior: if use_cache and past_key_values is None, create DynamicCache
// The cache persists across forward calls to enable incremental decoding
void
*
cache_to_use
=
kv_cache
;
if
(
cache_to_use
==
nullptr
)
{
// Create or reuse persistent internal cache at model level
// This ensures the cache persists across multiple forward calls (prefill -> decode -> decode...)
if
(
external_cache_
!=
nullptr
)
{
cache_to_use
=
external_cache_
;
}
else
{
// Fall back to internal cache
if
(
!
internal_cache_
)
{
internal_cache_
=
std
::
make_unique
<
infinilm
::
cache
::
DynamicCache
>
(
config_
.
num_hidden_layers
,
config_
.
max_position_embeddings
);
}
cache_to_use
=
internal_cache_
.
get
();
}
}
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
// 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size]
auto
hidden_states
=
embed_tokens_
->
forward
(
input_ids
);
auto
hidden_states
=
embed_tokens_
->
forward
(
input_ids
);
// 2. Process through all decoder layers
// 2. Process through all decoder layers
size_t
num_layers
=
layers_
.
size
();
size_t
num_layers
=
layers_
.
size
();
for
(
size_t
i
=
0
;
i
<
num_layers
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num_layers
;
++
i
)
{
// Pass model-level cache (layer index is now a property of the layer)
hidden_states
=
layers_
.
at
(
i
)
->
forward
(
hidden_states
,
position_ids
,
kv_cache_
,
cache_positions
);
hidden_states
=
layers_
.
at
(
i
)
->
forward
(
hidden_states
,
position_ids
,
cache_to_use
);
// DEBUG: Disabled previous final layer logging
// Logging moved to decoder layer for post-attention normalization
}
}
// 3. Apply final layer normalization to last token only (aligns with transformers)
// 3. Apply final layer normalization to last token only (aligns with transformers)
// Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
// Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
auto
shape
=
hidden_states
->
shape
();
auto
shape
=
hidden_states
->
shape
();
size_t
seq_len
=
shape
[
1
];
size_t
seq_len
=
shape
[
1
];
auto
last_token
=
hidden_states
->
narrow
({{
1
,
seq_len
-
1
,
1
}});
auto
last_token
=
hidden_states
->
narrow
({{
1
,
seq_len
-
1
,
1
}});
// DEBUG: Disabled previous final layer normalization logging
// Normalize only the last token (matches Python backend)
auto
normalized_last_token
=
norm_
->
forward
(
last_token
);
auto
normalized_last_token
=
norm_
->
forward
(
last_token
);
return
normalized_last_token
;
return
normalized_last_token
;
}
}
void
LlamaModel
::
reset_cache
(
size_t
pos
)
const
{
void
LlamaModel
::
reset_cache
(
const
cache
::
CacheConfig
*
cache_config
)
{
if
(
internal_cache_
)
{
if
(
cache_config
==
nullptr
)
{
internal_cache_
->
reset
(
pos
);
kv_cache_
=
nullptr
;
}
return
;
if
(
external_cache_
)
{
external_cache_
->
reset
(
pos
);
}
}
void
LlamaModel
::
reset_cache
(
const
cache
::
CacheConfig
&
new_config
,
size_t
pos
)
const
{
if
(
internal_cache_
)
{
internal_cache_
->
update_config
(
new_config
);
internal_cache_
->
reset
(
pos
);
}
}
if
(
external_cache_
)
{
if
(
auto
kv_cache_config
=
dynamic_cast
<
const
cache
::
StaticKVCacheConfig
*>
(
cache_config
))
{
external_cache_
->
update_config
(
new_config
);
kv_cache_
=
std
::
make_shared
<
cache
::
StaticKVCache
>
(
external_cache_
->
reset
(
pos
);
config_
.
head_dim
,
config_
.
head_dim
,
config_
.
num_key_value_heads
,
config_
.
num_key_value_heads
,
config_
.
num_hidden_layers
,
config_
.
max_position_embeddings
,
config_
.
dtype
,
*
kv_cache_config
,
rank_info_
);
}
else
{
throw
std
::
runtime_error
(
"Unsupported cache type"
);
}
}
}
}
...
...
csrc/models/llama/llama_model.hpp
View file @
ff00b5c8
...
@@ -47,41 +47,19 @@ public:
...
@@ -47,41 +47,19 @@ public:
*
*
* @param input_ids Token IDs tensor of shape [batch, seq_len]
* @param input_ids Token IDs tensor of shape [batch, seq_len]
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param
kv_
cache
Optional model-level KV cache for incremental decoding
* @param cache
_positions Cache positions tensor of shape [n_req]
* @return Output tensor of shape [batch, seq_len, hidden_size]
* @return Output tensor of shape [batch, seq_len, hidden_size]
*/
*/
infinicore
::
Tensor
forward
(
const
infinicore
::
Tensor
&
input_ids
,
infinicore
::
Tensor
forward
(
const
infinicore
::
Tensor
&
input_ids
,
const
infinicore
::
Tensor
&
position_ids
,
const
infinicore
::
Tensor
&
position_ids
,
void
*
kv_cache
=
nullptr
)
const
;
const
infinicore
::
Tensor
&
cache_positions
)
const
;
void
reset_cache
(
const
cache
::
CacheConfig
*
cache_config
);
// Module information
// Module information
const
LlamaConfig
&
config
()
const
{
return
config_
;
}
const
LlamaConfig
&
config
()
const
{
return
config_
;
}
size_t
num_layers
()
const
{
return
config_
.
num_hidden_layers
;
}
size_t
num_layers
()
const
{
return
config_
.
num_hidden_layers
;
}
/**
* @brief Reset the internal cache to a specific position
* This should be called when starting a new generation sequence to prevent state
* from persisting between different questions/prompts
* @param pos Position to reset to (defaults to 0)
*/
void
reset_cache
(
size_t
pos
=
0
)
const
;
/**
* @brief Reset the internal cache with a new configuration and position
* This should be called when changing cache parameters (e.g., initial capacity)
* @param new_config New cache configuration
* @param pos Position to reset to
*/
void
reset_cache
(
const
cache
::
CacheConfig
&
new_config
,
size_t
pos
=
0
)
const
;
/**
* @brief Set external cache for the model
* @param cache Pointer to external cache (managed by CacheManager)
*/
void
set_external_cache
(
std
::
shared_ptr
<
cache
::
DynamicCache
>
cache
)
{
external_cache_
=
cache
.
get
();
}
protected:
protected:
// Token embeddings
// Token embeddings
INFINICORE_NN_MODULE
(
infinicore
::
nn
::
Embedding
,
embed_tokens
);
INFINICORE_NN_MODULE
(
infinicore
::
nn
::
Embedding
,
embed_tokens
);
...
@@ -95,13 +73,12 @@ protected:
...
@@ -95,13 +73,12 @@ protected:
// Rotary Position Embeddings (shared across all layers)
// Rotary Position Embeddings (shared across all layers)
INFINICORE_NN_MODULE
(
infinicore
::
nn
::
RoPE
,
rotary_emb
);
INFINICORE_NN_MODULE
(
infinicore
::
nn
::
RoPE
,
rotary_emb
);
engine
::
distributed
::
RankInfo
rank_info_
;
std
::
shared_ptr
<
cache
::
Cache
>
kv_cache_
;
private:
private:
LlamaConfig
config_
;
LlamaConfig
config_
;
// Persistent cache for when no external cache is provided
// Mutable because it's not part of the model's learned parameters,
// but needs to persist across forward calls for incremental decoding
mutable
std
::
unique_ptr
<
infinilm
::
cache
::
DynamicCache
>
internal_cache_
;
cache
::
DynamicCache
*
external_cache_
=
nullptr
;
};
};
}
// namespace infinilm::models::llama
}
// namespace infinilm::models::llama
csrc/models/model_factory.cpp
View file @
ff00b5c8
...
@@ -5,20 +5,21 @@ namespace infinilm {
...
@@ -5,20 +5,21 @@ namespace infinilm {
std
::
shared_ptr
<
InfinilmModel
>
InfinilmModelFactory
::
createModel
(
std
::
shared_ptr
<
InfinilmModel
>
InfinilmModelFactory
::
createModel
(
const
InfinilmModel
::
Config
&
config
,
const
InfinilmModel
::
Config
&
config
,
engine
::
distributed
::
RankInfo
rank_info
,
engine
::
distributed
::
RankInfo
rank_info
,
std
::
shared_ptr
<
cache
::
DynamicCache
>
cache
_ptr
)
{
const
cache
::
CacheConfig
*
cache
)
{
std
::
shared_ptr
<
InfinilmModel
>
model
;
if
(
const
auto
llama_config_ptr
=
dynamic_cast
<
const
models
::
llama
::
LlamaConfig
*>
(
&
config
))
{
if
(
const
auto
llama_config_ptr
=
dynamic_cast
<
const
models
::
llama
::
LlamaConfig
*>
(
&
config
))
{
const
auto
&
llama_config
=
*
llama_config_ptr
;
const
auto
&
llama_config
=
*
llama_config_ptr
;
auto
model
=
std
::
make_shared
<
models
::
llama
::
LlamaForCausalLM
>
(
model
=
std
::
make_shared
<
models
::
llama
::
LlamaForCausalLM
>
(
llama_config
,
rank_info
.
device
,
rank_info
);
llama_config
,
rank_info
.
device
,
rank_info
);
if
(
cache_ptr
!=
nullptr
)
{
model
->
model
().
set_external_cache
(
cache_ptr
);
}
return
model
;
}
else
{
}
else
{
throw
std
::
invalid_argument
(
"InfinilmModelFactory::createModel: Unsupported model config type"
);
throw
std
::
invalid_argument
(
"InfinilmModelFactory::createModel: Unsupported model config type"
);
}
}
if
(
cache
)
{
model
->
reset_cache
(
cache
);
}
return
model
;
}
}
}
// namespace infinilm
}
// namespace infinilm
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment