Commit ee59b3f5 authored by wooway777's avatar wooway777
Browse files

issue/214 - update attn and caching logic

parent 67e8d6e9
......@@ -96,7 +96,6 @@ StaticKVCache::update(size_t layer_idx,
if (device.getType() == infinicore::Device::Type::NVIDIA
|| device.getType() == infinicore::Device::Type::ILUVATAR
|| device.getType() == infinicore::Device::Type::METAX
|| device.getType() == infinicore::Device::Type::MOORE
|| device.getType() == infinicore::Device::Type::CAMBRICON) {
infinicore::op::kv_caching_(
k_cache_layer,
......
......@@ -127,8 +127,6 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
infinicore::Tensor attn_output;
if (q_reshaped->device().getType() == infinicore::Device::Type::NVIDIA
|| q_reshaped->device().getType() == infinicore::Device::Type::METAX
|| q_reshaped->device().getType() == infinicore::Device::Type::MOORE
|| q_reshaped->device().getType() == infinicore::Device::Type::ILUVATAR
|| q_reshaped->device().getType() == infinicore::Device::Type::CAMBRICON) {
attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scaling_, true);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment