Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
13f98ed3
Commit
13f98ed3
authored
May 26, 2025
by
PanZezhong
Browse files
fix workspace
parent
366d3aef
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
104 additions
and
13 deletions
+104
-13
scripts/jiuge.py
scripts/jiuge.py
+27
-3
scripts/libinfinicore_infer.py
scripts/libinfinicore_infer.py
+4
-3
src/allocator.hpp
src/allocator.hpp
+1
-2
src/allocator/workspace_allocator.cpp
src/allocator/workspace_allocator.cpp
+5
-1
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+65
-3
src/models/jiuge/jiuge_impl.hpp
src/models/jiuge/jiuge_impl.hpp
+2
-1
No files found.
scripts/jiuge.py
View file @
13f98ed3
...
@@ -13,6 +13,7 @@ from libinfinicore_infer import (
...
@@ -13,6 +13,7 @@ from libinfinicore_infer import (
DataType
,
DataType
,
DeviceType
,
DeviceType
,
create_jiuge_model
,
create_jiuge_model
,
destroy_jiuge_model
,
create_kv_cache
,
create_kv_cache
,
drop_kv_cache
,
drop_kv_cache
,
infer_batch
,
infer_batch
,
...
@@ -281,6 +282,9 @@ class JiugeForCauslLM:
...
@@ -281,6 +282,9 @@ class JiugeForCauslLM:
for
name_
in
data_
.
keys
():
for
name_
in
data_
.
keys
():
tensors_
[
name_
]
=
data_
.
get_tensor
(
name_
)
tensors_
[
name_
]
=
data_
.
get_tensor
(
name_
)
return
tensors_
return
tensors_
print
(
"Loading model weights to host..."
)
load_start_time
=
time
.
time
()
with
open
(
os
.
path
.
join
(
model_dir_path
,
"config.json"
),
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
model_dir_path
,
"config.json"
),
"r"
)
as
f
:
config
=
json
.
load
(
f
)
config
=
json
.
load
(
f
)
...
@@ -293,7 +297,12 @@ class JiugeForCauslLM:
...
@@ -293,7 +297,12 @@ class JiugeForCauslLM:
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
)
)
elif
"fm9g"
==
config
[
"model_type"
]:
elif
"fm9g"
==
config
[
"model_type"
]:
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
if
any
(
file
.
suffix
==
".safetensors"
for
file
in
Path
(
model_dir_path
).
iterdir
()):
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
else
:
state_dict
=
torch
.
load
(
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
...
@@ -317,6 +326,12 @@ class JiugeForCauslLM:
...
@@ -317,6 +326,12 @@ class JiugeForCauslLM:
else
:
else
:
raise
ValueError
(
"Unsupported model architecture"
)
raise
ValueError
(
"Unsupported model architecture"
)
load_end_time
=
time
.
time
()
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
print
(
f
"Creating model on
{
ndev
}
devices..."
)
load_start_time
=
time
.
time
()
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
self
.
model_instance
=
create_jiuge_model
(
self
.
model_instance
=
create_jiuge_model
(
byref
(
self
.
meta
),
byref
(
self
.
meta
),
...
@@ -325,18 +340,21 @@ class JiugeForCauslLM:
...
@@ -325,18 +340,21 @@ class JiugeForCauslLM:
ndev
,
ndev
,
dev_ids
,
dev_ids
,
)
)
load_end_time
=
time
.
time
()
print
(
f
"Time used:
{
load_end_time
-
load_start_time
:.
3
f
}
s"
)
def
infer
(
self
,
input_list
,
topp
=
1.0
,
topk
=
1
,
temperature
=
1.0
):
def
infer
(
self
,
input_list
,
topp
=
1.0
,
topk
=
1
,
temperature
=
1.0
):
pass
pass
def
generate
(
self
,
input_content
,
max_steps
,
topp
=
1.0
,
topk
=
1
,
temperature
=
1.0
):
def
generate
(
self
,
input_content
,
max_steps
,
topp
=
1.0
,
topk
=
1
,
temperature
=
1.0
):
kv_cache
=
create_kv_cache
(
self
.
model_instance
)
input_content
=
self
.
tokenizer
.
apply_chat_template
(
input_content
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
add_generation_prompt
=
True
,
add_generation_prompt
=
True
,
tokenize
=
False
,
tokenize
=
False
,
)
)
print
(
input_content
,
end
=
""
,
flush
=
True
)
print
(
input_content
,
end
=
""
,
flush
=
True
)
kv_cache
=
create_kv_cache
(
self
.
model_instance
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
ntok
=
len
(
tokens
)
ntok
=
len
(
tokens
)
nreq
=
1
nreq
=
1
...
@@ -367,6 +385,7 @@ class JiugeForCauslLM:
...
@@ -367,6 +385,7 @@ class JiugeForCauslLM:
)
)
steps
+=
1
steps
+=
1
output_tokens
=
list
(
ans
)
output_tokens
=
list
(
ans
)
end_time
=
time
.
time
()
output_str
=
(
output_str
=
(
self
.
tokenizer
.
_tokenizer
.
id_to_token
(
output_tokens
[
0
])
self
.
tokenizer
.
_tokenizer
.
id_to_token
(
output_tokens
[
0
])
.
replace
(
"▁"
,
" "
)
.
replace
(
"▁"
,
" "
)
...
@@ -380,7 +399,7 @@ class JiugeForCauslLM:
...
@@ -380,7 +399,7 @@ class JiugeForCauslLM:
ntok
=
1
ntok
=
1
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
end_time
=
time
.
time
()
if
step_i
>
0
:
if
step_i
>
0
:
total_time
+=
end_time
-
start_time
total_time
+=
end_time
-
start_time
...
@@ -390,6 +409,10 @@ class JiugeForCauslLM:
...
@@ -390,6 +409,10 @@ class JiugeForCauslLM:
for
kv_cache
in
kv_caches
:
for
kv_cache
in
kv_caches
:
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
return
output_content
,
avg_time
return
output_content
,
avg_time
def
destroy_model_instance
(
self
):
destroy_jiuge_model
(
self
.
model_instance
)
print
(
"Model destroyed"
)
def
test
():
def
test
():
...
@@ -421,6 +444,7 @@ def test():
...
@@ -421,6 +444,7 @@ def test():
ndev
=
int
(
sys
.
argv
[
3
])
if
len
(
sys
.
argv
)
>
3
else
1
ndev
=
int
(
sys
.
argv
[
3
])
if
len
(
sys
.
argv
)
>
3
else
1
model
=
JiugeForCauslLM
(
model_path
,
device_type
,
ndev
)
model
=
JiugeForCauslLM
(
model_path
,
device_type
,
ndev
)
model
.
generate
(
"山东最高的山是?"
,
500
)
model
.
generate
(
"山东最高的山是?"
,
500
)
model
.
destroy_model_instance
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
scripts/libinfinicore_infer.py
View file @
13f98ed3
...
@@ -92,12 +92,12 @@ def __open_library__():
...
@@ -92,12 +92,12 @@ def __open_library__():
c_int
,
# int ndev
c_int
,
# int ndev
POINTER
(
c_int
),
# int const *dev_ids
POINTER
(
c_int
),
# int const *dev_ids
]
]
lib
.
destroyJiugeModel
.
argtypes
=
[
POINTER
(
JiugeModel
)]
lib
.
createKVCache
.
restype
=
POINTER
(
KVCache
)
lib
.
createKVCache
.
restype
=
POINTER
(
KVCache
)
lib
.
dropKVCache
.
argtypes
=
[
ctypes
.
POINTER
(
JiugeModel
),
POINTER
(
KVCache
)]
lib
.
dropKVCache
.
argtypes
=
[
POINTER
(
JiugeModel
),
POINTER
(
KVCache
)]
lib
.
inferBatch
.
restype
=
None
lib
.
inferBatch
.
restype
=
None
lib
.
inferBatch
.
argtypes
=
[
lib
.
inferBatch
.
argtypes
=
[
ctypes
.
POINTER
(
JiugeModel
),
# struct JiugeModel const *
POINTER
(
JiugeModel
),
# struct JiugeModel const *
POINTER
(
c_uint
),
# unsigned int const *tokens
POINTER
(
c_uint
),
# unsigned int const *tokens
c_uint
,
# unsigned int ntok
c_uint
,
# unsigned int ntok
POINTER
(
c_uint
),
# unsigned int const *req_lens
POINTER
(
c_uint
),
# unsigned int const *req_lens
...
@@ -116,6 +116,7 @@ def __open_library__():
...
@@ -116,6 +116,7 @@ def __open_library__():
LIB
=
__open_library__
()
LIB
=
__open_library__
()
create_jiuge_model
=
LIB
.
createJiugeModel
create_jiuge_model
=
LIB
.
createJiugeModel
destroy_jiuge_model
=
LIB
.
destroyJiugeModel
create_kv_cache
=
LIB
.
createKVCache
create_kv_cache
=
LIB
.
createKVCache
drop_kv_cache
=
LIB
.
dropKVCache
drop_kv_cache
=
LIB
.
dropKVCache
infer_batch
=
LIB
.
inferBatch
infer_batch
=
LIB
.
inferBatch
src/allocator.hpp
View file @
13f98ed3
...
@@ -13,8 +13,7 @@ class WorkspaceAllocator : public AllocatorBase {
...
@@ -13,8 +13,7 @@ class WorkspaceAllocator : public AllocatorBase {
private:
private:
void
*
_memory
;
void
*
_memory
;
size_t
_total_size
;
size_t
_total_size
;
size_t
_used_size
;
size_t
_align
;
size_t
_align
=
256
;
public:
public:
WorkspaceAllocator
(
size_t
intial_size
,
size_t
align
=
256
);
WorkspaceAllocator
(
size_t
intial_size
,
size_t
align
=
256
);
...
...
src/allocator/workspace_allocator.cpp
View file @
13f98ed3
...
@@ -14,6 +14,8 @@ inline void *allocate(size_t size_) {
...
@@ -14,6 +14,8 @@ inline void *allocate(size_t size_) {
WorkspaceAllocator
::
WorkspaceAllocator
(
size_t
initial_size_
,
size_t
align
)
{
WorkspaceAllocator
::
WorkspaceAllocator
(
size_t
initial_size_
,
size_t
align
)
{
_align
=
align
;
_align
=
align
;
_total_size
=
0
;
_memory
=
nullptr
;
if
(
initial_size_
>
0
)
{
if
(
initial_size_
>
0
)
{
_total_size
=
aligned_size
(
initial_size_
,
_align
);
_total_size
=
aligned_size
(
initial_size_
,
_align
);
_memory
=
allocate
(
_total_size
);
_memory
=
allocate
(
_total_size
);
...
@@ -23,9 +25,10 @@ WorkspaceAllocator::WorkspaceAllocator(size_t initial_size_, size_t align) {
...
@@ -23,9 +25,10 @@ WorkspaceAllocator::WorkspaceAllocator(size_t initial_size_, size_t align) {
void
*
WorkspaceAllocator
::
alloc
(
size_t
new_size
)
{
void
*
WorkspaceAllocator
::
alloc
(
size_t
new_size
)
{
if
(
_total_size
<
new_size
)
{
if
(
_total_size
<
new_size
)
{
if
(
_total_size
!=
0
)
{
if
(
_total_size
!=
0
)
{
RUN_INFINI
(
infinirtDeviceSynchronize
());
RUN_INFINI
(
infinirtFree
(
_memory
));
RUN_INFINI
(
infinirtFree
(
_memory
));
}
}
_total_size
=
aligned_size
(
new_size
*
3
/
2
,
_align
);
_total_size
=
aligned_size
(
new_size
,
_align
);
_memory
=
allocate
(
_total_size
);
_memory
=
allocate
(
_total_size
);
}
}
return
_memory
;
return
_memory
;
...
@@ -36,6 +39,7 @@ void WorkspaceAllocator::release(void *ptr) {
...
@@ -36,6 +39,7 @@ void WorkspaceAllocator::release(void *ptr) {
WorkspaceAllocator
::~
WorkspaceAllocator
()
{
WorkspaceAllocator
::~
WorkspaceAllocator
()
{
if
(
_memory
!=
nullptr
)
{
if
(
_memory
!=
nullptr
)
{
RUN_INFINI
(
infinirtDeviceSynchronize
());
RUN_INFINI
(
infinirtFree
(
_memory
));
RUN_INFINI
(
infinirtFree
(
_memory
));
}
}
}
}
\ No newline at end of file
src/models/jiuge/jiuge.cpp
View file @
13f98ed3
...
@@ -61,6 +61,52 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
...
@@ -61,6 +61,52 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
comm
,
comm
,
std
::
make_unique
<
WorkspaceAllocator
>
(
0
),
std
::
make_unique
<
WorkspaceAllocator
>
(
0
),
};
};
RUN_INFINI
(
infinirtDeviceSynchronize
());
}
void
releaseDeviceResource
(
DeviceResource
&
res
)
{
infinirtDeviceSynchronize
();
// Release individual Tensors
res
.
w_in_embd
.
reset
();
res
.
w_out_norm
.
reset
();
res
.
w_out_embd
.
reset
();
res
.
sin_table
.
reset
();
res
.
cos_table
.
reset
();
for
(
auto
&
t
:
res
.
w_attn_norm
)
{
t
.
reset
();
}
res
.
w_attn_norm
.
clear
();
for
(
auto
&
t
:
res
.
w_attn_qkv
)
{
t
.
reset
();
}
res
.
w_attn_qkv
.
clear
();
for
(
auto
&
t
:
res
.
b_attn_qkv
)
{
t
.
reset
();
}
res
.
b_attn_qkv
.
clear
();
for
(
auto
&
t
:
res
.
w_attn_out
)
{
t
.
reset
();
}
res
.
w_attn_out
.
clear
();
for
(
auto
&
t
:
res
.
w_ffn_norm
)
{
t
.
reset
();
}
res
.
w_ffn_norm
.
clear
();
for
(
auto
&
t
:
res
.
w_ffn_gate_up
)
{
t
.
reset
();
}
res
.
w_ffn_gate_up
.
clear
();
for
(
auto
&
t
:
res
.
w_ffn_down
)
{
t
.
reset
();
}
res
.
w_ffn_down
.
clear
();
res
.
workspace_allocator
.
reset
();
infiniopDestroyHandle
(
res
.
handle
);
res
.
handle
=
nullptr
;
infinirtStreamDestroy
(
res
.
stream
);
res
.
stream
=
nullptr
;
infinicclCommDestroy
(
res
.
comm
);
res
.
comm
=
nullptr
;
}
}
void
inferDeviceBatch
(
const
JiugeMeta
&
meta
,
DeviceResource
&
rsrc
,
void
inferDeviceBatch
(
const
JiugeMeta
&
meta
,
DeviceResource
&
rsrc
,
...
@@ -291,6 +337,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -291,6 +337,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI
(
infinicclAllReduce
(
RUN_INFINI
(
infinicclAllReduce
(
logits_in
->
data
(),
logits_in
->
data
(),
ntok
*
d
,
dt_logits
,
logits_in
->
data
(),
logits_in
->
data
(),
ntok
*
d
,
dt_logits
,
INFINICCL_SUM
,
rsrc
.
comm
,
stream
));
INFINICCL_SUM
,
rsrc
.
comm
,
stream
));
RUN_INFINI
(
infinirtStreamSynchronize
(
stream
));
}
}
// 2. FFN
// 2. FFN
// rms_norm
// rms_norm
...
@@ -315,6 +362,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -315,6 +362,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI
(
infinicclAllReduce
(
RUN_INFINI
(
infinicclAllReduce
(
logits_in
->
data
(),
logits_in
->
data
(),
ntok
*
d
,
dt_logits
,
logits_in
->
data
(),
logits_in
->
data
(),
ntok
*
d
,
dt_logits
,
INFINICCL_SUM
,
rsrc
.
comm
,
stream
));
INFINICCL_SUM
,
rsrc
.
comm
,
stream
));
RUN_INFINI
(
infinirtStreamSynchronize
(
stream
));
}
}
}
}
// Sample and Output
// Sample and Output
...
@@ -408,10 +456,20 @@ inferBatch(struct JiugeModel *model,
...
@@ -408,10 +456,20 @@ inferBatch(struct JiugeModel *model,
void
launchDevice
(
const
JiugeMeta
&
meta
,
const
JiugeWeights
*
weights
,
DeviceResource
*
rsrc
,
InferState
&
state
,
InferRequest
&
req
,
void
launchDevice
(
const
JiugeMeta
&
meta
,
const
JiugeWeights
*
weights
,
DeviceResource
*
rsrc
,
InferState
&
state
,
InferRequest
&
req
,
infiniDevice_t
device
,
int
idev
,
int
ndev
,
int
dev_id
,
infinicclComm_t
comm
)
{
infiniDevice_t
device
,
int
idev
,
int
ndev
,
int
dev_id
,
infinicclComm_t
comm
)
{
// Create Device Resource
createDeviceResource
(
rsrc
,
&
meta
,
weights
,
device
,
idev
,
ndev
,
dev_id
,
comm
);
createDeviceResource
(
rsrc
,
&
meta
,
weights
,
device
,
idev
,
ndev
,
dev_id
,
comm
);
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
state
.
mtx
);
state
.
loaded
=
true
;
lock
.
unlock
();
state
.
cv_load
.
notify_one
();
}
// Infer Loop
while
(
true
)
{
while
(
true
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
state
.
mtx
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
state
.
mtx
);
state
.
cv_start
.
wait
(
lock
,
[
&
]
{
return
state
.
proceed
||
state
.
exit_flag
;
});
state
.
cv_start
.
wait
(
lock
,
[
&
]
{
return
state
.
proceed
||
state
.
exit_flag
;
});
// quit if exit_flag is set
if
(
state
.
exit_flag
)
{
if
(
state
.
exit_flag
)
{
break
;
break
;
}
}
...
@@ -423,9 +481,8 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
...
@@ -423,9 +481,8 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
state
.
cv_done
.
notify_one
();
state
.
cv_done
.
notify_one
();
}
}
infiniopDestroyHandle
(
rsrc
->
handle
);
// Clean-Up
infinirtStreamDestroy
(
rsrc
->
stream
);
releaseDeviceResource
(
*
rsrc
);
infinicclCommDestroy
(
rsrc
->
comm
);
}
}
JiugeModel
::
JiugeModel
(
const
JiugeMeta
*
_meta
,
const
JiugeWeights
*
weights
,
infiniDevice_t
device_
,
std
::
vector
<
int
>
device_ids
)
:
meta
(
*
_meta
)
{
JiugeModel
::
JiugeModel
(
const
JiugeMeta
*
_meta
,
const
JiugeWeights
*
weights
,
infiniDevice_t
device_
,
std
::
vector
<
int
>
device_ids
)
:
meta
(
*
_meta
)
{
...
@@ -444,6 +501,11 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi
...
@@ -444,6 +501,11 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi
for
(
int
i
=
0
;
i
<
ndev
;
i
++
)
{
for
(
int
i
=
0
;
i
<
ndev
;
i
++
)
{
threads
[
i
]
=
std
::
thread
(
launchDevice
,
std
::
cref
(
meta
),
weights
,
&
dev_resources
[
i
],
std
::
ref
(
states
[
i
]),
std
::
ref
(
req
),
device
,
i
,
ndev
,
dev_ids
[
i
],
comms
[
i
]);
threads
[
i
]
=
std
::
thread
(
launchDevice
,
std
::
cref
(
meta
),
weights
,
&
dev_resources
[
i
],
std
::
ref
(
states
[
i
]),
std
::
ref
(
req
),
device
,
i
,
ndev
,
dev_ids
[
i
],
comms
[
i
]);
}
}
for
(
int
i
=
0
;
i
<
ndev
;
i
++
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
states
[
i
].
mtx
);
states
[
i
].
cv_load
.
wait
(
lock
,
[
&
]
{
return
states
[
i
].
loaded
;
});
lock
.
unlock
();
}
}
}
__C
struct
JiugeModel
*
__C
struct
JiugeModel
*
...
...
src/models/jiuge/jiuge_impl.hpp
View file @
13f98ed3
...
@@ -32,7 +32,8 @@ struct DeviceResource {
...
@@ -32,7 +32,8 @@ struct DeviceResource {
struct
InferState
{
struct
InferState
{
std
::
mutex
mtx
;
std
::
mutex
mtx
;
std
::
condition_variable
cv_start
,
cv_done
;
std
::
condition_variable
cv_load
,
cv_start
,
cv_done
;
bool
loaded
=
false
;
bool
proceed
=
false
;
bool
proceed
=
false
;
bool
exit_flag
=
false
;
bool
exit_flag
=
false
;
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment