Commit 06be2dc3 authored by wooway777's avatar wooway777
Browse files

issue/231 - remove ninetoothed dependency by default

parent a4ced800
...@@ -63,6 +63,12 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA ...@@ -63,6 +63,12 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
``` ```
- 选择是否使用九齿计算路径,默认为false,即不依赖九齿算子
```bash
xmake f --ninetoothed=[true|false] -cv
```
- 安装 InfiniLM Python 包 - 安装 InfiniLM Python 包
```bash ```bash
pip install -e . pip install -e .
......
...@@ -93,26 +93,24 @@ StaticKVCache::update(size_t layer_idx, ...@@ -93,26 +93,24 @@ StaticKVCache::update(size_t layer_idx,
auto device = k_cache_layer->device(); auto device = k_cache_layer->device();
if (device.getType() == infinicore::Device::Type::NVIDIA #ifdef ENABLE_NINETOOTHED
|| device.getType() == infinicore::Device::Type::ILUVATAR infinicore::op::kv_caching_(
|| device.getType() == infinicore::Device::Type::METAX) { k_cache_layer,
infinicore::op::kv_caching_( v_cache_layer,
k_cache_layer, k,
v_cache_layer, v,
k, past_sequence_lengths);
v, #else
past_sequence_lengths); size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
} else { auto result_len = cache_pos + update_len;
size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0]; ASSERT(result_len <= cache_len_);
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_); auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}}); k_cache_update->copy_from(k);
v_cache_update->copy_from(v);
k_cache_update->copy_from(k); #endif
v_cache_update->copy_from(v);
}
return {k_cache_layer, v_cache_layer}; return {k_cache_layer, v_cache_layer};
} }
......
...@@ -8,6 +8,16 @@ set_toolchains("gcc") ...@@ -8,6 +8,16 @@ set_toolchains("gcc")
add_includedirs("third_party/spdlog/include") add_includedirs("third_party/spdlog/include")
add_includedirs("third_party/json/single_include/") add_includedirs("third_party/json/single_include/")
-- Build option that selects the NineToothed-specific code path.
-- Off by default; turn it on with `xmake f --ninetoothed=true`.
option("ninetoothed")
    set_default(false)
    set_showmenu(true)
    set_description("Whether to compile NineToothed specific path")
option_end()

-- When the option is enabled, expose ENABLE_NINETOOTHED to the compiled
-- sources so they can switch paths with #ifdef.
if has_config("ninetoothed") then
    add_defines("ENABLE_NINETOOTHED")
end
target("infinicore_infer") target("infinicore_infer")
set_kind("shared") set_kind("shared")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment