"googlemock/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "9518a57428ae0a7ed450c1361768e84a2a38af5a"
Unverified Commit d09de04c authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #250 from InfiniTensor/issue/248

Issue/248 support flash-attention
parents f67956fe 5dc85bf4
...@@ -252,6 +252,13 @@ def get_args(): ...@@ -252,6 +252,13 @@ def get_args():
action="store_true", action="store_true",
help="Perform a warmup run before benchmarking/inference.", help="Perform a warmup run before benchmarking/inference.",
) )
parser.add_argument(
"--attn",
type=str,
default="default",
choices=["default", "flash-attn"],
help="attention backend to use: 'default' or 'flash-attn'",
)
return parser.parse_args() return parser.parse_args()
...@@ -278,6 +285,7 @@ class TestModel: ...@@ -278,6 +285,7 @@ class TestModel:
skip_load=False, skip_load=False,
cache_config=None, cache_config=None,
enable_graph=False, enable_graph=False,
attn_backend="default",
) -> None: ) -> None:
model_path = os.path.expanduser(model_path) model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -289,6 +297,7 @@ class TestModel: ...@@ -289,6 +297,7 @@ class TestModel:
distributed_config=DistConfig(tp), distributed_config=DistConfig(tp),
cache_config=cache_config, cache_config=cache_config,
enable_graph_compiling=enable_graph, enable_graph_compiling=enable_graph,
attention_backend=attn_backend,
) )
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -461,6 +470,7 @@ if __name__ == "__main__": ...@@ -461,6 +470,7 @@ if __name__ == "__main__":
skip_load=skip_load, skip_load=skip_load,
cache_config=cache_config, cache_config=cache_config,
enable_graph=enable_graph, enable_graph=enable_graph,
attn_backend=args.attn,
) )
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
......
...@@ -142,6 +142,14 @@ def get_args(): ...@@ -142,6 +142,14 @@ def get_args():
help="sampling temperature", help="sampling temperature",
) )
parser.add_argument(
"--attn",
type=str,
default="default",
choices=["default", "flash-attn"],
help="attention backend to use: 'default' or 'flash-attn'",
)
return parser.parse_args() return parser.parse_args()
...@@ -156,6 +164,7 @@ def test( ...@@ -156,6 +164,7 @@ def test(
top_k=1, top_k=1,
top_p=1.0, top_p=1.0,
temperature=1.0, temperature=1.0,
attn_backend="default",
): ):
model_path = os.path.expanduser(model_path) model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
...@@ -166,6 +175,7 @@ def test( ...@@ -166,6 +175,7 @@ def test(
device=infini_device, device=infini_device,
distributed_config=DistConfig(tp), distributed_config=DistConfig(tp),
enable_graph_compiling=enable_graph, enable_graph_compiling=enable_graph,
attention_backend=attn_backend,
) )
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# Load Weights # Load Weights
...@@ -333,4 +343,5 @@ if __name__ == "__main__": ...@@ -333,4 +343,5 @@ if __name__ == "__main__":
top_k=args.top_k, top_k=args.top_k,
top_p=args.top_p, top_p=args.top_p,
temperature=args.temperature, temperature=args.temperature,
attn_backend=args.attn,
) )
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <infinirt.h> #include <infinirt.h>
__C __export struct KVCache *createKVCache( __INFINI_C __export struct KVCache *createKVCache(
size_t nlayers, size_t nlayers,
size_t max_len, size_t max_len,
size_t nkvh_, size_t nkvh_,
...@@ -14,8 +14,8 @@ __C __export struct KVCache *createKVCache( ...@@ -14,8 +14,8 @@ __C __export struct KVCache *createKVCache(
int *dev_ids, int *dev_ids,
size_t ndev); size_t ndev);
__C __export struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len); __INFINI_C __export struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len);
__C __export void dropKVCache(KVCache *kv_cache); __INFINI_C __export void dropKVCache(KVCache *kv_cache);
#endif /* CACHE_H */ #endif /* CACHE_H */
...@@ -103,26 +103,26 @@ typedef struct { ...@@ -103,26 +103,26 @@ typedef struct {
/// @param device 协处理器种类 /// @param device 协处理器种类
/// @param ndev 协处理器数量 /// @param ndev 协处理器数量
/// @param dev_ids 协处理器编号,长度为 ndev /// @param dev_ids 协处理器编号,长度为 ndev
__C __export struct DeepSeekV3Model * __INFINI_C __export struct DeepSeekV3Model *
createDeepSeekV3Model(const DeepSeekV3Meta *, createDeepSeekV3Model(const DeepSeekV3Meta *,
const DeepSeekV3Weights *); const DeepSeekV3Weights *);
__C DeepSeekV3Weights * __INFINI_C DeepSeekV3Weights *
createDeepSeekV3Weights(const DeepSeekV3Meta *meta, createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
infiniDevice_t device, infiniDevice_t device,
int ndev, int ndev,
const int *dev_ids); const int *dev_ids);
__C __export DeepSeekV3WeightLoader * __INFINI_C __export DeepSeekV3WeightLoader *
createDeepSeekV3WeightLoader(); createDeepSeekV3WeightLoader();
/// @brief 销毁模型 /// @brief 销毁模型
__C __export void destroyDeepSeekV3Model(struct DeepSeekV3Model *); __INFINI_C __export void destroyDeepSeekV3Model(struct DeepSeekV3Model *);
__C __export struct DeepSeekV3Cache * __INFINI_C __export struct DeepSeekV3Cache *
createDeepSeekV3Cache(const struct DeepSeekV3Model *); createDeepSeekV3Cache(const struct DeepSeekV3Model *);
__C __export void __INFINI_C __export void
dropDeepSeekV3Cache(const struct DeepSeekV3Model *, dropDeepSeekV3Cache(const struct DeepSeekV3Model *,
struct DeepSeekV3Cache *); struct DeepSeekV3Cache *);
...@@ -137,7 +137,7 @@ dropDeepSeekV3Cache(const struct DeepSeekV3Model *, ...@@ -137,7 +137,7 @@ dropDeepSeekV3Cache(const struct DeepSeekV3Model *,
/// @param topk 采样 topk(1 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp /// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
inferBatchDeepSeekV3(struct DeepSeekV3Model *, inferBatchDeepSeekV3(struct DeepSeekV3Model *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
...@@ -153,7 +153,7 @@ inferBatchDeepSeekV3(struct DeepSeekV3Model *, ...@@ -153,7 +153,7 @@ inferBatchDeepSeekV3(struct DeepSeekV3Model *,
/// @param req_pos 每个请求的起始位置 /// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache /// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
forwardBatchDeepSeekV3(struct DeepSeekV3Model *, forwardBatchDeepSeekV3(struct DeepSeekV3Model *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......
...@@ -54,7 +54,7 @@ typedef struct ...@@ -54,7 +54,7 @@ typedef struct
/// @param device 协处理器种类 /// @param device 协处理器种类
/// @param ndev 协处理器数量 /// @param ndev 协处理器数量
/// @param dev_ids 协处理器编号,长度为 ndev /// @param dev_ids 协处理器编号,长度为 ndev
__C __export struct JiugeModel * __INFINI_C __export struct JiugeModel *
createJiugeModel(const JiugeMeta *, createJiugeModel(const JiugeMeta *,
const JiugeWeights *, const JiugeWeights *,
infiniDevice_t device, infiniDevice_t device,
...@@ -62,7 +62,7 @@ createJiugeModel(const JiugeMeta *, ...@@ -62,7 +62,7 @@ createJiugeModel(const JiugeMeta *,
const int *dev_ids); const int *dev_ids);
/// @brief 销毁模型 /// @brief 销毁模型
__C __export void __INFINI_C __export void
destroyJiugeModel(struct JiugeModel *); destroyJiugeModel(struct JiugeModel *);
/// @brief 批次推理一轮,并采样出新的 token /// @brief 批次推理一轮,并采样出新的 token
...@@ -76,7 +76,7 @@ destroyJiugeModel(struct JiugeModel *); ...@@ -76,7 +76,7 @@ destroyJiugeModel(struct JiugeModel *);
/// @param topk 采样 topk(1 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp /// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
inferBatchJiuge(struct JiugeModel *, inferBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
...@@ -92,7 +92,7 @@ inferBatchJiuge(struct JiugeModel *, ...@@ -92,7 +92,7 @@ inferBatchJiuge(struct JiugeModel *,
/// @param req_pos 每个请求的起始位置 /// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache /// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
forwardBatchJiuge(struct JiugeModel *, forwardBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......
...@@ -25,7 +25,7 @@ typedef struct ...@@ -25,7 +25,7 @@ typedef struct
} JiugeAWQMeta; } JiugeAWQMeta;
//////////////////// APIs /////////////////////// //////////////////// APIs ///////////////////////
__C __export struct ModelWeights * __INFINI_C __export struct ModelWeights *
createJiugeAWQWeights(const JiugeAWQMeta *, createJiugeAWQWeights(const JiugeAWQMeta *,
infiniDevice_t device, infiniDevice_t device,
int ndev, int ndev,
...@@ -34,12 +34,12 @@ createJiugeAWQWeights(const JiugeAWQMeta *, ...@@ -34,12 +34,12 @@ createJiugeAWQWeights(const JiugeAWQMeta *,
/// @param device 协处理器种类 /// @param device 协处理器种类
/// @param ndev 协处理器数量 /// @param ndev 协处理器数量
/// @param dev_ids 协处理器编号,长度为 ndev /// @param dev_ids 协处理器编号,长度为 ndev
__C __export struct JiugeAWQModel * __INFINI_C __export struct JiugeAWQModel *
createJiugeAWQModel(const JiugeAWQMeta *, createJiugeAWQModel(const JiugeAWQMeta *,
const ModelWeights *); const ModelWeights *);
/// @brief 销毁模型 /// @brief 销毁模型
__C __export void __INFINI_C __export void
destroyJiugeAWQModel(struct JiugeAWQModel *); destroyJiugeAWQModel(struct JiugeAWQModel *);
/// @brief 批次推理一轮,并采样出新的 token /// @brief 批次推理一轮,并采样出新的 token
...@@ -53,7 +53,7 @@ destroyJiugeAWQModel(struct JiugeAWQModel *); ...@@ -53,7 +53,7 @@ destroyJiugeAWQModel(struct JiugeAWQModel *);
/// @param topk 采样 topk(1 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp /// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
inferBatchJiugeAWQ(struct JiugeAWQModel *, inferBatchJiugeAWQ(struct JiugeAWQModel *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
...@@ -69,7 +69,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *, ...@@ -69,7 +69,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *,
/// @param req_pos 每个请求的起始位置 /// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache /// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
forwardBatchJiugeAWQ(struct JiugeAWQModel *, forwardBatchJiugeAWQ(struct JiugeAWQModel *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......
...@@ -5,10 +5,10 @@ ...@@ -5,10 +5,10 @@
struct ModelWeights; struct ModelWeights;
__C __export void __INFINI_C __export void
loadModelWeight(struct ModelWeights *weights, const char *name, void *data); loadModelWeight(struct ModelWeights *weights, const char *name, void *data);
__C __export void __INFINI_C __export void
loadModelWeightDistributed(struct ModelWeights *weights, const char *name, void *data, int *ranks, int nrank); loadModelWeightDistributed(struct ModelWeights *weights, const char *name, void *data, int *ranks, int nrank);
#endif // WEIGHTS_LOADER_H #endif // WEIGHTS_LOADER_H
...@@ -29,6 +29,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -29,6 +29,7 @@ class InferEngine(_infinilm.InferEngine):
distributed_config=DistConfig(1), distributed_config=DistConfig(1),
cache_config=None, cache_config=None,
enable_graph_compiling=False, enable_graph_compiling=False,
attention_backend="default",
): ):
self.config = AutoConfig.from_pretrained(model_path) self.config = AutoConfig.from_pretrained(model_path)
...@@ -41,6 +42,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -41,6 +42,7 @@ class InferEngine(_infinilm.InferEngine):
device._underlying.type, device._underlying.type,
cache_config, cache_config,
enable_graph_compiling, enable_graph_compiling,
attention_backend,
) )
self.use_cache = False self.use_cache = False
...@@ -57,6 +59,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -57,6 +59,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths=None, past_kv_lengths=None,
total_kv_lengths=None, total_kv_lengths=None,
input_offsets=None, input_offsets=None,
cu_seqlens=None,
block_tables=None, block_tables=None,
slot_mapping=None, slot_mapping=None,
temperature=None, temperature=None,
...@@ -74,6 +77,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -74,6 +77,7 @@ class InferEngine(_infinilm.InferEngine):
) )
input_offsets = input_offsets._underlying if input_offsets is not None else None input_offsets = input_offsets._underlying if input_offsets is not None else None
block_tables = block_tables._underlying if block_tables is not None else None block_tables = block_tables._underlying if block_tables is not None else None
cu_seqlens = cu_seqlens._underlying if cu_seqlens is not None else None
slot_mapping = slot_mapping._underlying if slot_mapping is not None else None slot_mapping = slot_mapping._underlying if slot_mapping is not None else None
return infinicore.Tensor( return infinicore.Tensor(
...@@ -85,6 +89,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -85,6 +89,7 @@ class InferEngine(_infinilm.InferEngine):
past_sequence_lengths=past_kv_lengths, past_sequence_lengths=past_kv_lengths,
total_sequence_lengths=total_kv_lengths, total_sequence_lengths=total_kv_lengths,
input_offsets=input_offsets, input_offsets=input_offsets,
cu_seqlens=cu_seqlens,
block_tables=block_tables, block_tables=block_tables,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
temperature=temperature, temperature=temperature,
...@@ -135,7 +140,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -135,7 +140,7 @@ class InferEngine(_infinilm.InferEngine):
] ]
block_tables = infinicore.from_list( block_tables = infinicore.from_list(
block_tables_list, block_tables_list,
dtype=infinicore.int64, dtype=infinicore.int32,
) )
for iter in range(0, generation_config.max_new_tokens): for iter in range(0, generation_config.max_new_tokens):
...@@ -188,14 +193,17 @@ class InferEngine(_infinilm.InferEngine): ...@@ -188,14 +193,17 @@ class InferEngine(_infinilm.InferEngine):
slot_mapping = None slot_mapping = None
past_kv_lengths = infinicore.from_list( past_kv_lengths = infinicore.from_list(
[past_seq_len] * batch_size, dtype=infinicore.int64 [past_seq_len] * batch_size, dtype=infinicore.int32
) )
total_kv_lengths = infinicore.from_list( total_kv_lengths = infinicore.from_list(
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int64 [past_seq_len + seq_len] * batch_size, dtype=infinicore.int32
)
cu_seqlens = infinicore.from_list(
[(past_seq_len + seq_len) * i for i in range(batch_size + 1)],
dtype=infinicore.int32,
) )
input_offsets = infinicore.from_list( input_offsets = infinicore.from_list(
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64 [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32
) )
output_id = self( output_id = self(
...@@ -204,6 +212,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -204,6 +212,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths=past_kv_lengths, past_kv_lengths=past_kv_lengths,
total_kv_lengths=total_kv_lengths, total_kv_lengths=total_kv_lengths,
input_offsets=input_offsets, input_offsets=input_offsets,
cu_seqlens=cu_seqlens,
block_tables=block_tables, block_tables=block_tables,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
temperature=generation_config.temperature, temperature=generation_config.temperature,
......
#include "../cache.hpp" #include "../cache.hpp"
__C struct KVCache *createKVCache( __INFINI_C struct KVCache *createKVCache(
size_t nlayers, size_t nlayers,
size_t max_len, size_t max_len,
size_t nkvh_, size_t nkvh_,
...@@ -31,7 +31,7 @@ __C struct KVCache *createKVCache( ...@@ -31,7 +31,7 @@ __C struct KVCache *createKVCache(
return cache; return cache;
} }
__C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) { __INFINI_C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) {
auto ndev = kv_cache->k.size(); auto ndev = kv_cache->k.size();
auto nlayers = kv_cache->k[0].size(); auto nlayers = kv_cache->k[0].size();
auto device = kv_cache->k[0][0]->deviceType(); auto device = kv_cache->k[0][0]->deviceType();
...@@ -65,7 +65,7 @@ __C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) { ...@@ -65,7 +65,7 @@ __C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) {
return new_kv_cache; return new_kv_cache;
} }
__C void dropKVCache(KVCache *kv_cache) { __INFINI_C void dropKVCache(KVCache *kv_cache) {
auto ndev = kv_cache->k.size(); auto ndev = kv_cache->k.size();
auto nlayers = kv_cache->k[0].size(); auto nlayers = kv_cache->k[0].size();
auto device = kv_cache->k[0][0]->deviceType(); auto device = kv_cache->k[0][0]->deviceType();
......
...@@ -78,7 +78,7 @@ std::shared_ptr<Tensor> Loader::get(const std::string &name, int rank) { ...@@ -78,7 +78,7 @@ std::shared_ptr<Tensor> Loader::get(const std::string &name, int rank) {
} // namespace infinicore::weights } // namespace infinicore::weights
__C void __INFINI_C void
loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) { loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) {
std::string name_str(name); std::string name_str(name);
auto weights = reinterpret_cast<infinicore::weights::Loader *>(weights_); auto weights = reinterpret_cast<infinicore::weights::Loader *>(weights_);
......
...@@ -431,7 +431,7 @@ void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc ...@@ -431,7 +431,7 @@ void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc
} }
} }
__C void __INFINI_C void
inferBatchDeepSeekV3(struct DeepSeekV3Model *model, inferBatchDeepSeekV3(struct DeepSeekV3Model *model,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
...@@ -464,7 +464,7 @@ inferBatchDeepSeekV3(struct DeepSeekV3Model *model, ...@@ -464,7 +464,7 @@ inferBatchDeepSeekV3(struct DeepSeekV3Model *model,
} }
} }
__C void __INFINI_C void
forwardBatchDeepSeekV3(struct DeepSeekV3Model *model, forwardBatchDeepSeekV3(struct DeepSeekV3Model *model,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
...@@ -563,14 +563,14 @@ DeepSeekV3Model::DeepSeekV3Model(const DeepSeekV3Meta *_meta, const DeepSeekV3We ...@@ -563,14 +563,14 @@ DeepSeekV3Model::DeepSeekV3Model(const DeepSeekV3Meta *_meta, const DeepSeekV3We
} }
} }
__C struct DeepSeekV3Model * __INFINI_C struct DeepSeekV3Model *
createDeepSeekV3Model(const DeepSeekV3Meta *_meta, createDeepSeekV3Model(const DeepSeekV3Meta *_meta,
const DeepSeekV3Weights *weights) { const DeepSeekV3Weights *weights) {
DeepSeekV3Model *model = new DeepSeekV3Model(_meta, weights); DeepSeekV3Model *model = new DeepSeekV3Model(_meta, weights);
return model; return model;
} }
__C void __INFINI_C void
destroyDeepSeekV3Model(struct DeepSeekV3Model *model) { destroyDeepSeekV3Model(struct DeepSeekV3Model *model) {
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
......
#include "deepseek_v3_impl.hpp" #include "deepseek_v3_impl.hpp"
__C struct DeepSeekV3Cache * __INFINI_C struct DeepSeekV3Cache *
createDeepSeekV3Cache(const struct DeepSeekV3Model *model) { createDeepSeekV3Cache(const struct DeepSeekV3Model *model) {
DeepSeekV3Cache *cache = new DeepSeekV3Cache(); DeepSeekV3Cache *cache = new DeepSeekV3Cache();
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
...@@ -25,7 +25,7 @@ createDeepSeekV3Cache(const struct DeepSeekV3Model *model) { ...@@ -25,7 +25,7 @@ createDeepSeekV3Cache(const struct DeepSeekV3Model *model) {
return cache; return cache;
} }
__C void __INFINI_C void
dropDeepSeekV3Cache(const struct DeepSeekV3Model *model, dropDeepSeekV3Cache(const struct DeepSeekV3Model *model,
struct DeepSeekV3Cache *cache) { struct DeepSeekV3Cache *cache) {
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
......
...@@ -436,7 +436,7 @@ static DeepSeekV3WeightLoader weight_loader = { ...@@ -436,7 +436,7 @@ static DeepSeekV3WeightLoader weight_loader = {
.load_mlp_experts = load_mlp_experts, .load_mlp_experts = load_mlp_experts,
}; };
__C DeepSeekV3Weights * __INFINI_C DeepSeekV3Weights *
createDeepSeekV3Weights(const DeepSeekV3Meta *meta, createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
infiniDevice_t device, infiniDevice_t device,
int ndev, int ndev,
...@@ -445,7 +445,7 @@ createDeepSeekV3Weights(const DeepSeekV3Meta *meta, ...@@ -445,7 +445,7 @@ createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
return weights; return weights;
}; };
__C DeepSeekV3WeightLoader * __INFINI_C DeepSeekV3WeightLoader *
createDeepSeekV3WeightLoader() { createDeepSeekV3WeightLoader() {
return &weight_loader; return &weight_loader;
} }
...@@ -315,7 +315,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, ...@@ -315,7 +315,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
} }
} }
__C void __INFINI_C void
inferBatchJiuge(struct JiugeModel *model, inferBatchJiuge(struct JiugeModel *model,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
...@@ -348,7 +348,7 @@ inferBatchJiuge(struct JiugeModel *model, ...@@ -348,7 +348,7 @@ inferBatchJiuge(struct JiugeModel *model,
} }
} }
__C void __INFINI_C void
forwardBatchJiuge(struct JiugeModel *model, forwardBatchJiuge(struct JiugeModel *model,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
...@@ -444,7 +444,7 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi ...@@ -444,7 +444,7 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi
} }
} }
__C struct JiugeModel * __INFINI_C struct JiugeModel *
createJiugeModel(const JiugeMeta *meta, createJiugeModel(const JiugeMeta *meta,
const JiugeWeights *weights, const JiugeWeights *weights,
infiniDevice_t device, infiniDevice_t device,
...@@ -456,7 +456,7 @@ createJiugeModel(const JiugeMeta *meta, ...@@ -456,7 +456,7 @@ createJiugeModel(const JiugeMeta *meta,
return model; return model;
} }
__C void destroyJiugeModel(struct JiugeModel *model) { __INFINI_C void destroyJiugeModel(struct JiugeModel *model) {
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
for (size_t idev = 0; idev < ndev; idev++) { for (size_t idev = 0; idev < ndev; idev++) {
......
...@@ -242,7 +242,7 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc, ...@@ -242,7 +242,7 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc,
} }
} }
__C void __INFINI_C void
inferBatchJiugeAWQ(struct JiugeAWQModel *model, inferBatchJiugeAWQ(struct JiugeAWQModel *model,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
...@@ -275,7 +275,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *model, ...@@ -275,7 +275,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *model,
} }
} }
__C void __INFINI_C void
forwardBatchJiugeAWQ(struct JiugeAWQModel *model, forwardBatchJiugeAWQ(struct JiugeAWQModel *model,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
...@@ -372,14 +372,14 @@ JiugeAWQModel::JiugeAWQModel(const JiugeAWQMeta *meta, const ModelWeights *weigh ...@@ -372,14 +372,14 @@ JiugeAWQModel::JiugeAWQModel(const JiugeAWQMeta *meta, const ModelWeights *weigh
} }
} }
__C struct JiugeAWQModel * __INFINI_C struct JiugeAWQModel *
createJiugeAWQModel(const JiugeAWQMeta *meta, createJiugeAWQModel(const JiugeAWQMeta *meta,
const ModelWeights *weights) { const ModelWeights *weights) {
JiugeAWQModel *model = new JiugeAWQModel(meta, weights); JiugeAWQModel *model = new JiugeAWQModel(meta, weights);
return model; return model;
} }
__C void destroyJiugeAWQModel(struct JiugeAWQModel *model) { __INFINI_C void destroyJiugeAWQModel(struct JiugeAWQModel *model) {
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
for (size_t idev = 0; idev < ndev; idev++) { for (size_t idev = 0; idev < ndev; idev++) {
......
...@@ -118,7 +118,7 @@ JiugeAWQWeights::JiugeAWQWeights( ...@@ -118,7 +118,7 @@ JiugeAWQWeights::JiugeAWQWeights(
#undef REGISTER_LAYER_QUANT_WEIGHT #undef REGISTER_LAYER_QUANT_WEIGHT
} }
__C struct ModelWeights * __INFINI_C struct ModelWeights *
createJiugeAWQWeights(const JiugeAWQMeta *meta, createJiugeAWQWeights(const JiugeAWQMeta *meta,
infiniDevice_t device, infiniDevice_t device,
int ndev, int ndev,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment