Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
d09de04c
Unverified
Commit
d09de04c
authored
Mar 06, 2026
by
thatPepe
Committed by
GitHub
Mar 06, 2026
Browse files
Merge pull request #250 from InfiniTensor/issue/248
Issue/248 support flash-attention
parents
f67956fe
5dc85bf4
Changes
36
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
78 additions
and
48 deletions
+78
-48
examples/bench.py
examples/bench.py
+10
-0
examples/jiuge.py
examples/jiuge.py
+11
-0
include/infinicore_infer/cache.h
include/infinicore_infer/cache.h
+3
-3
include/infinicore_infer/models/deepseek.h
include/infinicore_infer/models/deepseek.h
+8
-8
include/infinicore_infer/models/jiuge.h
include/infinicore_infer/models/jiuge.h
+4
-4
include/infinicore_infer/models/jiuge_awq.h
include/infinicore_infer/models/jiuge_awq.h
+5
-5
include/infinicore_infer/weights_loader.h
include/infinicore_infer/weights_loader.h
+2
-2
python/infinilm/infer_engine.py
python/infinilm/infer_engine.py
+14
-5
src/cache_manager/kvcache.cpp
src/cache_manager/kvcache.cpp
+3
-3
src/dataloader/weights_loader.cpp
src/dataloader/weights_loader.cpp
+1
-1
src/models/deepseek_v3/deepseek_v3.cpp
src/models/deepseek_v3/deepseek_v3.cpp
+4
-4
src/models/deepseek_v3/deepseek_v3_cache.cpp
src/models/deepseek_v3/deepseek_v3_cache.cpp
+2
-2
src/models/deepseek_v3/deepseek_v3_weight.cpp
src/models/deepseek_v3/deepseek_v3_weight.cpp
+2
-2
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+4
-4
src/models/jiuge_awq/jiuge_awq.cpp
src/models/jiuge_awq/jiuge_awq.cpp
+4
-4
src/models/jiuge_awq/jiuge_awq_weight.cpp
src/models/jiuge_awq/jiuge_awq_weight.cpp
+1
-1
No files found.
examples/bench.py
View file @
d09de04c
...
...
@@ -252,6 +252,13 @@ def get_args():
action
=
"store_true"
,
help
=
"Perform a warmup run before benchmarking/inference."
,
)
parser
.
add_argument
(
"--attn"
,
type
=
str
,
default
=
"default"
,
choices
=
[
"default"
,
"flash-attn"
],
help
=
"attention backend to use: 'default' or 'flash-attn'"
,
)
return
parser
.
parse_args
()
...
...
@@ -278,6 +285,7 @@ class TestModel:
skip_load
=
False
,
cache_config
=
None
,
enable_graph
=
False
,
attn_backend
=
"default"
,
)
->
None
:
model_path
=
os
.
path
.
expanduser
(
model_path
)
# ---------------------------------------------------------------------------- #
...
...
@@ -289,6 +297,7 @@ class TestModel:
distributed_config
=
DistConfig
(
tp
),
cache_config
=
cache_config
,
enable_graph_compiling
=
enable_graph
,
attention_backend
=
attn_backend
,
)
# ---------------------------------------------------------------------------- #
...
...
@@ -461,6 +470,7 @@ if __name__ == "__main__":
skip_load
=
skip_load
,
cache_config
=
cache_config
,
enable_graph
=
enable_graph
,
attn_backend
=
args
.
attn
,
)
# ---------------------------------------------------------------------------- #
...
...
examples/jiuge.py
View file @
d09de04c
...
...
@@ -142,6 +142,14 @@ def get_args():
help
=
"sampling temperature"
,
)
parser
.
add_argument
(
"--attn"
,
type
=
str
,
default
=
"default"
,
choices
=
[
"default"
,
"flash-attn"
],
help
=
"attention backend to use: 'default' or 'flash-attn'"
,
)
return
parser
.
parse_args
()
...
...
@@ -156,6 +164,7 @@ def test(
top_k
=
1
,
top_p
=
1.0
,
temperature
=
1.0
,
attn_backend
=
"default"
,
):
model_path
=
os
.
path
.
expanduser
(
model_path
)
# ---------------------------------------------------------------------------- #
...
...
@@ -166,6 +175,7 @@ def test(
device
=
infini_device
,
distributed_config
=
DistConfig
(
tp
),
enable_graph_compiling
=
enable_graph
,
attention_backend
=
attn_backend
,
)
# ---------------------------------------------------------------------------- #
# Load Weights
...
...
@@ -333,4 +343,5 @@ if __name__ == "__main__":
top_k
=
args
.
top_k
,
top_p
=
args
.
top_p
,
temperature
=
args
.
temperature
,
attn_backend
=
args
.
attn
,
)
include/infinicore_infer/cache.h
View file @
d09de04c
...
...
@@ -3,7 +3,7 @@
#include <infinirt.h>
__C
__export
struct
KVCache
*
createKVCache
(
__
INFINI_
C
__export
struct
KVCache
*
createKVCache
(
size_t
nlayers
,
size_t
max_len
,
size_t
nkvh_
,
...
...
@@ -14,8 +14,8 @@ __C __export struct KVCache *createKVCache(
int
*
dev_ids
,
size_t
ndev
);
__C
__export
struct
KVCache
*
duplicateKVCache
(
const
KVCache
*
kv_cache
,
size_t
seq_len
);
__
INFINI_
C
__export
struct
KVCache
*
duplicateKVCache
(
const
KVCache
*
kv_cache
,
size_t
seq_len
);
__C
__export
void
dropKVCache
(
KVCache
*
kv_cache
);
__
INFINI_
C
__export
void
dropKVCache
(
KVCache
*
kv_cache
);
#endif
/* CACHE_H */
include/infinicore_infer/models/deepseek.h
View file @
d09de04c
...
...
@@ -103,26 +103,26 @@ typedef struct {
/// @param device 协处理器种类
/// @param ndev 协处理器数量
/// @param dev_ids 协处理器编号,长度为 ndev
__C
__export
struct
DeepSeekV3Model
*
__
INFINI_
C
__export
struct
DeepSeekV3Model
*
createDeepSeekV3Model
(
const
DeepSeekV3Meta
*
,
const
DeepSeekV3Weights
*
);
__C
DeepSeekV3Weights
*
__
INFINI_
C
DeepSeekV3Weights
*
createDeepSeekV3Weights
(
const
DeepSeekV3Meta
*
meta
,
infiniDevice_t
device
,
int
ndev
,
const
int
*
dev_ids
);
__C
__export
DeepSeekV3WeightLoader
*
__
INFINI_
C
__export
DeepSeekV3WeightLoader
*
createDeepSeekV3WeightLoader
();
/// @brief 销毁模型
__C
__export
void
destroyDeepSeekV3Model
(
struct
DeepSeekV3Model
*
);
__
INFINI_
C
__export
void
destroyDeepSeekV3Model
(
struct
DeepSeekV3Model
*
);
__C
__export
struct
DeepSeekV3Cache
*
__
INFINI_
C
__export
struct
DeepSeekV3Cache
*
createDeepSeekV3Cache
(
const
struct
DeepSeekV3Model
*
);
__C
__export
void
__
INFINI_
C
__export
void
dropDeepSeekV3Cache
(
const
struct
DeepSeekV3Model
*
,
struct
DeepSeekV3Cache
*
);
...
...
@@ -137,7 +137,7 @@ dropDeepSeekV3Cache(const struct DeepSeekV3Model *,
/// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C
__export
void
__
INFINI_
C
__export
void
inferBatchDeepSeekV3
(
struct
DeepSeekV3Model
*
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
@@ -153,7 +153,7 @@ inferBatchDeepSeekV3(struct DeepSeekV3Model *,
/// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
__C
__export
void
__
INFINI_
C
__export
void
forwardBatchDeepSeekV3
(
struct
DeepSeekV3Model
*
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
include/infinicore_infer/models/jiuge.h
View file @
d09de04c
...
...
@@ -54,7 +54,7 @@ typedef struct
/// @param device 协处理器种类
/// @param ndev 协处理器数量
/// @param dev_ids 协处理器编号,长度为 ndev
__C
__export
struct
JiugeModel
*
__
INFINI_
C
__export
struct
JiugeModel
*
createJiugeModel
(
const
JiugeMeta
*
,
const
JiugeWeights
*
,
infiniDevice_t
device
,
...
...
@@ -62,7 +62,7 @@ createJiugeModel(const JiugeMeta *,
const
int
*
dev_ids
);
/// @brief 销毁模型
__C
__export
void
__
INFINI_
C
__export
void
destroyJiugeModel
(
struct
JiugeModel
*
);
/// @brief 批次推理一轮,并采样出新的 token
...
...
@@ -76,7 +76,7 @@ destroyJiugeModel(struct JiugeModel *);
/// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C
__export
void
__
INFINI_
C
__export
void
inferBatchJiuge
(
struct
JiugeModel
*
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
@@ -92,7 +92,7 @@ inferBatchJiuge(struct JiugeModel *,
/// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
__C
__export
void
__
INFINI_
C
__export
void
forwardBatchJiuge
(
struct
JiugeModel
*
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
include/infinicore_infer/models/jiuge_awq.h
View file @
d09de04c
...
...
@@ -25,7 +25,7 @@ typedef struct
}
JiugeAWQMeta
;
//////////////////// APIs ///////////////////////
__C
__export
struct
ModelWeights
*
__
INFINI_
C
__export
struct
ModelWeights
*
createJiugeAWQWeights
(
const
JiugeAWQMeta
*
,
infiniDevice_t
device
,
int
ndev
,
...
...
@@ -34,12 +34,12 @@ createJiugeAWQWeights(const JiugeAWQMeta *,
/// @param device 协处理器种类
/// @param ndev 协处理器数量
/// @param dev_ids 协处理器编号,长度为 ndev
__C
__export
struct
JiugeAWQModel
*
__
INFINI_
C
__export
struct
JiugeAWQModel
*
createJiugeAWQModel
(
const
JiugeAWQMeta
*
,
const
ModelWeights
*
);
/// @brief 销毁模型
__C
__export
void
__
INFINI_
C
__export
void
destroyJiugeAWQModel
(
struct
JiugeAWQModel
*
);
/// @brief 批次推理一轮,并采样出新的 token
...
...
@@ -53,7 +53,7 @@ destroyJiugeAWQModel(struct JiugeAWQModel *);
/// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C
__export
void
__
INFINI_
C
__export
void
inferBatchJiugeAWQ
(
struct
JiugeAWQModel
*
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
@@ -69,7 +69,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *,
/// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
__C
__export
void
__
INFINI_
C
__export
void
forwardBatchJiugeAWQ
(
struct
JiugeAWQModel
*
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
include/infinicore_infer/weights_loader.h
View file @
d09de04c
...
...
@@ -5,10 +5,10 @@
struct
ModelWeights
;
__C
__export
void
__
INFINI_
C
__export
void
loadModelWeight
(
struct
ModelWeights
*
weights
,
const
char
*
name
,
void
*
data
);
__C
__export
void
__
INFINI_
C
__export
void
loadModelWeightDistributed
(
struct
ModelWeights
*
weights
,
const
char
*
name
,
void
*
data
,
int
*
ranks
,
int
nrank
);
#endif // WEIGHTS_LOADER_H
python/infinilm/infer_engine.py
View file @
d09de04c
...
...
@@ -29,6 +29,7 @@ class InferEngine(_infinilm.InferEngine):
distributed_config
=
DistConfig
(
1
),
cache_config
=
None
,
enable_graph_compiling
=
False
,
attention_backend
=
"default"
,
):
self
.
config
=
AutoConfig
.
from_pretrained
(
model_path
)
...
...
@@ -41,6 +42,7 @@ class InferEngine(_infinilm.InferEngine):
device
.
_underlying
.
type
,
cache_config
,
enable_graph_compiling
,
attention_backend
,
)
self
.
use_cache
=
False
...
...
@@ -57,6 +59,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths
=
None
,
total_kv_lengths
=
None
,
input_offsets
=
None
,
cu_seqlens
=
None
,
block_tables
=
None
,
slot_mapping
=
None
,
temperature
=
None
,
...
...
@@ -74,6 +77,7 @@ class InferEngine(_infinilm.InferEngine):
)
input_offsets
=
input_offsets
.
_underlying
if
input_offsets
is
not
None
else
None
block_tables
=
block_tables
.
_underlying
if
block_tables
is
not
None
else
None
cu_seqlens
=
cu_seqlens
.
_underlying
if
cu_seqlens
is
not
None
else
None
slot_mapping
=
slot_mapping
.
_underlying
if
slot_mapping
is
not
None
else
None
return
infinicore
.
Tensor
(
...
...
@@ -85,6 +89,7 @@ class InferEngine(_infinilm.InferEngine):
past_sequence_lengths
=
past_kv_lengths
,
total_sequence_lengths
=
total_kv_lengths
,
input_offsets
=
input_offsets
,
cu_seqlens
=
cu_seqlens
,
block_tables
=
block_tables
,
slot_mapping
=
slot_mapping
,
temperature
=
temperature
,
...
...
@@ -135,7 +140,7 @@ class InferEngine(_infinilm.InferEngine):
]
block_tables
=
infinicore
.
from_list
(
block_tables_list
,
dtype
=
infinicore
.
int
64
,
dtype
=
infinicore
.
int
32
,
)
for
iter
in
range
(
0
,
generation_config
.
max_new_tokens
):
...
...
@@ -188,14 +193,17 @@ class InferEngine(_infinilm.InferEngine):
slot_mapping
=
None
past_kv_lengths
=
infinicore
.
from_list
(
[
past_seq_len
]
*
batch_size
,
dtype
=
infinicore
.
int
64
[
past_seq_len
]
*
batch_size
,
dtype
=
infinicore
.
int
32
)
total_kv_lengths
=
infinicore
.
from_list
(
[
past_seq_len
+
seq_len
]
*
batch_size
,
dtype
=
infinicore
.
int64
[
past_seq_len
+
seq_len
]
*
batch_size
,
dtype
=
infinicore
.
int32
)
cu_seqlens
=
infinicore
.
from_list
(
[(
past_seq_len
+
seq_len
)
*
i
for
i
in
range
(
batch_size
+
1
)],
dtype
=
infinicore
.
int32
,
)
input_offsets
=
infinicore
.
from_list
(
[
seq_len
*
i
for
i
in
range
(
batch_size
+
1
)],
dtype
=
infinicore
.
int
64
[
seq_len
*
i
for
i
in
range
(
batch_size
+
1
)],
dtype
=
infinicore
.
int
32
)
output_id
=
self
(
...
...
@@ -204,6 +212,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths
=
past_kv_lengths
,
total_kv_lengths
=
total_kv_lengths
,
input_offsets
=
input_offsets
,
cu_seqlens
=
cu_seqlens
,
block_tables
=
block_tables
,
slot_mapping
=
slot_mapping
,
temperature
=
generation_config
.
temperature
,
...
...
src/cache_manager/kvcache.cpp
View file @
d09de04c
#include "../cache.hpp"
__C
struct
KVCache
*
createKVCache
(
__
INFINI_
C
struct
KVCache
*
createKVCache
(
size_t
nlayers
,
size_t
max_len
,
size_t
nkvh_
,
...
...
@@ -31,7 +31,7 @@ __C struct KVCache *createKVCache(
return
cache
;
}
__C
struct
KVCache
*
duplicateKVCache
(
const
KVCache
*
kv_cache
,
size_t
seq_len
)
{
__
INFINI_
C
struct
KVCache
*
duplicateKVCache
(
const
KVCache
*
kv_cache
,
size_t
seq_len
)
{
auto
ndev
=
kv_cache
->
k
.
size
();
auto
nlayers
=
kv_cache
->
k
[
0
].
size
();
auto
device
=
kv_cache
->
k
[
0
][
0
]
->
deviceType
();
...
...
@@ -65,7 +65,7 @@ __C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) {
return
new_kv_cache
;
}
__C
void
dropKVCache
(
KVCache
*
kv_cache
)
{
__
INFINI_
C
void
dropKVCache
(
KVCache
*
kv_cache
)
{
auto
ndev
=
kv_cache
->
k
.
size
();
auto
nlayers
=
kv_cache
->
k
[
0
].
size
();
auto
device
=
kv_cache
->
k
[
0
][
0
]
->
deviceType
();
...
...
src/dataloader/weights_loader.cpp
View file @
d09de04c
...
...
@@ -78,7 +78,7 @@ std::shared_ptr<Tensor> Loader::get(const std::string &name, int rank) {
}
// namespace infinicore::weights
__C
void
__
INFINI_
C
void
loadModelWeight
(
struct
ModelWeights
*
weights_
,
const
char
*
name
,
void
*
data
)
{
std
::
string
name_str
(
name
);
auto
weights
=
reinterpret_cast
<
infinicore
::
weights
::
Loader
*>
(
weights_
);
...
...
src/models/deepseek_v3/deepseek_v3.cpp
View file @
d09de04c
...
...
@@ -431,7 +431,7 @@ void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc
}
}
__C
void
__
INFINI_
C
void
inferBatchDeepSeekV3
(
struct
DeepSeekV3Model
*
model
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
@@ -464,7 +464,7 @@ inferBatchDeepSeekV3(struct DeepSeekV3Model *model,
}
}
__C
void
__
INFINI_
C
void
forwardBatchDeepSeekV3
(
struct
DeepSeekV3Model
*
model
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
@@ -563,14 +563,14 @@ DeepSeekV3Model::DeepSeekV3Model(const DeepSeekV3Meta *_meta, const DeepSeekV3We
}
}
__C
struct
DeepSeekV3Model
*
__
INFINI_
C
struct
DeepSeekV3Model
*
createDeepSeekV3Model
(
const
DeepSeekV3Meta
*
_meta
,
const
DeepSeekV3Weights
*
weights
)
{
DeepSeekV3Model
*
model
=
new
DeepSeekV3Model
(
_meta
,
weights
);
return
model
;
}
__C
void
__
INFINI_
C
void
destroyDeepSeekV3Model
(
struct
DeepSeekV3Model
*
model
)
{
auto
ndev
=
model
->
dev_resources
.
size
();
...
...
src/models/deepseek_v3/deepseek_v3_cache.cpp
View file @
d09de04c
#include "deepseek_v3_impl.hpp"
__C
struct
DeepSeekV3Cache
*
__
INFINI_
C
struct
DeepSeekV3Cache
*
createDeepSeekV3Cache
(
const
struct
DeepSeekV3Model
*
model
)
{
DeepSeekV3Cache
*
cache
=
new
DeepSeekV3Cache
();
auto
ndev
=
model
->
dev_resources
.
size
();
...
...
@@ -25,7 +25,7 @@ createDeepSeekV3Cache(const struct DeepSeekV3Model *model) {
return
cache
;
}
__C
void
__
INFINI_
C
void
dropDeepSeekV3Cache
(
const
struct
DeepSeekV3Model
*
model
,
struct
DeepSeekV3Cache
*
cache
)
{
auto
ndev
=
model
->
dev_resources
.
size
();
...
...
src/models/deepseek_v3/deepseek_v3_weight.cpp
View file @
d09de04c
...
...
@@ -436,7 +436,7 @@ static DeepSeekV3WeightLoader weight_loader = {
.
load_mlp_experts
=
load_mlp_experts
,
};
__C
DeepSeekV3Weights
*
__
INFINI_
C
DeepSeekV3Weights
*
createDeepSeekV3Weights
(
const
DeepSeekV3Meta
*
meta
,
infiniDevice_t
device
,
int
ndev
,
...
...
@@ -445,7 +445,7 @@ createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
return
weights
;
};
__C
DeepSeekV3WeightLoader
*
__
INFINI_
C
DeepSeekV3WeightLoader
*
createDeepSeekV3WeightLoader
()
{
return
&
weight_loader
;
}
src/models/jiuge/jiuge.cpp
View file @
d09de04c
...
...
@@ -315,7 +315,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
}
}
__C
void
__
INFINI_
C
void
inferBatchJiuge
(
struct
JiugeModel
*
model
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
@@ -348,7 +348,7 @@ inferBatchJiuge(struct JiugeModel *model,
}
}
__C
void
__
INFINI_
C
void
forwardBatchJiuge
(
struct
JiugeModel
*
model
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
@@ -444,7 +444,7 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi
}
}
__C
struct
JiugeModel
*
__
INFINI_
C
struct
JiugeModel
*
createJiugeModel
(
const
JiugeMeta
*
meta
,
const
JiugeWeights
*
weights
,
infiniDevice_t
device
,
...
...
@@ -456,7 +456,7 @@ createJiugeModel(const JiugeMeta *meta,
return
model
;
}
__C
void
destroyJiugeModel
(
struct
JiugeModel
*
model
)
{
__
INFINI_
C
void
destroyJiugeModel
(
struct
JiugeModel
*
model
)
{
auto
ndev
=
model
->
dev_resources
.
size
();
for
(
size_t
idev
=
0
;
idev
<
ndev
;
idev
++
)
{
...
...
src/models/jiuge_awq/jiuge_awq.cpp
View file @
d09de04c
...
...
@@ -242,7 +242,7 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc,
}
}
__C
void
__
INFINI_
C
void
inferBatchJiugeAWQ
(
struct
JiugeAWQModel
*
model
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
@@ -275,7 +275,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *model,
}
}
__C
void
__
INFINI_
C
void
forwardBatchJiugeAWQ
(
struct
JiugeAWQModel
*
model
,
const
uint32_t
*
tokens
,
uint32_t
ntok
,
const
uint32_t
*
req_lens
,
uint32_t
nreq
,
const
uint32_t
*
req_pos
,
...
...
@@ -372,14 +372,14 @@ JiugeAWQModel::JiugeAWQModel(const JiugeAWQMeta *meta, const ModelWeights *weigh
}
}
__C
struct
JiugeAWQModel
*
__
INFINI_
C
struct
JiugeAWQModel
*
createJiugeAWQModel
(
const
JiugeAWQMeta
*
meta
,
const
ModelWeights
*
weights
)
{
JiugeAWQModel
*
model
=
new
JiugeAWQModel
(
meta
,
weights
);
return
model
;
}
__C
void
destroyJiugeAWQModel
(
struct
JiugeAWQModel
*
model
)
{
__
INFINI_
C
void
destroyJiugeAWQModel
(
struct
JiugeAWQModel
*
model
)
{
auto
ndev
=
model
->
dev_resources
.
size
();
for
(
size_t
idev
=
0
;
idev
<
ndev
;
idev
++
)
{
...
...
src/models/jiuge_awq/jiuge_awq_weight.cpp
View file @
d09de04c
...
...
@@ -118,7 +118,7 @@ JiugeAWQWeights::JiugeAWQWeights(
#undef REGISTER_LAYER_QUANT_WEIGHT
}
__C
struct
ModelWeights
*
__
INFINI_
C
struct
ModelWeights
*
createJiugeAWQWeights
(
const
JiugeAWQMeta
*
meta
,
infiniDevice_t
device
,
int
ndev
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment