jerrrrry / infinilm · Commits

Commit 80909bee (unverified)
Authored Oct 22, 2025 by pengcheng888; committed via GitHub on Oct 22, 2025.

Issue/60 - main: fix garbled output tokens and adapt the Qwen3 model

Parent: 753a4f60
Showing 10 changed files with 115 additions and 45 deletions (+115 -45).
include/infinicore_infer/models/jiuge.h    +4   -0
scripts/deepseek.py                        +1   -5
scripts/jiuge.py                           +57  -7
scripts/jiuge_awq.py                       +0   -5
scripts/launch_server.py                   +3   -10
scripts/libinfinicore_infer/jiuge.py       +2   -0
scripts/test_ceval.py                      +1   -5
src/models/jiuge/jiuge.cpp                 +30  -12
src/models/jiuge/jiuge_impl.hpp            +1   -1
src/models/jiuge/jiuge_weight.hpp          +16  -0
include/infinicore_infer/models/jiuge.h

@@ -35,6 +35,10 @@ typedef struct
     const void *const *attn_qkv;
     // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh]
     const void *const *attn_qkv_b;
+    // nlayer * [dh]
+    const void *const *attn_q_norm;
+    // nlayer * [dh]
+    const void *const *attn_k_norm;
     // nlayer * [ndev, d, nkvh / ndev * dh]
     const void *const *attn_o;
     // nlayer * [d]
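For context (inferred from the rest of this commit rather than stated in the header itself): each new array holds one pointer per layer to a [dh]-element RMSNorm scale vector used for Qwen3-style query/key normalization. A hedged sketch of the operation these vectors parameterize, per attention head and before RoPE:

# see the rmsnorm() calls added in src/models/jiuge/jiuge.cpp
# q_hat = q / sqrt(mean(q * q) + epsilon) * attn_q_norm[layer]   # scale has shape [dh]
# k_hat = k / sqrt(mean(k * k) + epsilon) * attn_k_norm[layer]   # scale has shape [dh]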
scripts/deepseek.py

@@ -662,11 +662,7 @@ class DeepSeekV3ForCauslLM:
             output_tokens = self.batch_infer_one_round([infer_task])
             end_time = time.time()
             steps += 1
-            output_str = (
-                self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            output_str = self.tokenizer.decode(output_tokens[0])
             output_content += output_str
             print(output_str, end="", flush=True)
             if output_tokens[0] in self.eos_token_id:
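Why this fixes the garbled output: id_to_token() returns raw tokenizer pieces, and SentencePiece tokenizers with byte fallback emit pieces such as "<0xE4>" for the individual bytes of non-ASCII characters. Replacing only "▁" and "<0x0A>" leaves those byte pieces as literal text, which appears as mojibake; tokenizer.decode() runs the full decoder pipeline and reassembles the bytes into characters. A minimal sketch, assuming a fast (Rust-backed) Hugging Face tokenizer and a placeholder model path:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/path/to/llama-style-model")  # placeholder path

ids = tok.encode("你好", add_special_tokens=False)

# Raw pieces, as the removed code used them: byte-fallback tokens like
# "<0xE4>", "<0xBD>", ... are not readable text on their own.
print([tok._tokenizer.id_to_token(i) for i in ids])

# decode() applies the decoder (ByteFallback + Fuse) and returns proper characters.
print(tok.decode(ids))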
scripts/jiuge.py

@@ -58,6 +58,12 @@ class LlamaWeightsNaming:
     def attn_v_b(self, i):
         return f"model.layers.{i}.self_attn.v_proj.bias"
 
+    def attn_q_norm(self, i):
+        return f"model.layers.{i}.self_attn.q_norm.weight"
+
+    def attn_k_norm(self, i):
+        return f"model.layers.{i}.self_attn.k_norm.weight"
+
     def ffn_norm(self, i):
         return f"model.layers.{i}.post_attention_layernorm.weight"

@@ -117,7 +123,7 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
                 if "num_key_value_heads" in config
                 else config["num_attention_heads"]
             ),
-            dh=config["hidden_size"] // config["num_attention_heads"],
+            dh=config["head_dim"] if "head_dim" in config else config["hidden_size"] // config["num_attention_heads"],
             di=config["intermediate_size"],
             dctx=(
                 config["max_position_embeddings"] if max_tokens is None else max_tokens

@@ -275,6 +281,35 @@ class JiugeWeightsImpl(JiugeWeightsCStruct):
         else:
             self.attn_qkv_b = None
 
+        if naming.attn_q_norm(0) in state_dict:
+            self.attn_q_norm_tensors = [
+                state_dict[naming.attn_q_norm(i)]
+                .reshape([2, dh // 2])
+                .transpose(0, 1)
+                .contiguous()
+                .to(torch_dt_norm)
+                for i in range(nlayer)
+            ]
+            self.attn_q_norm_ptrs = [
+                self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer)
+            ]
+            self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs)
+            self.attn_k_norm_tensors = [
+                state_dict[naming.attn_k_norm(i)]
+                .reshape([2, dh // 2])
+                .transpose(0, 1)
+                .contiguous()
+                .to(torch_dt_norm)
+                for i in range(nlayer)
+            ]
+            self.attn_k_norm_ptrs = [
+                self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer)
+            ]
+            self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs)
+        else:
+            self.attn_q_norm = None
+            self.attn_k_norm = None
+
         self.attn_o_tensor = [
             (
                 state_dict[naming.attn_o(i)]

@@ -481,7 +516,7 @@ class JiugeForCauslLM:
                 )
             else:
                 raise ValueError("Unsupported weight naming")
-        elif "qwen2" == config["model_type"]:
+        elif "qwen2" == config["model_type"] or "qwen3" == config["model_type"]:
             state_dict = load_all_safetensors_from_dir(model_dir_path)
             if LlamaWeightsNaming.match(state_dict):
                 self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)

@@ -498,6 +533,24 @@ class JiugeForCauslLM:
         else:
             raise ValueError("Unsupported model architecture")
 
+        if "llama" == config["model_type"]:
+            from tokenizers import decoders as _dec
+
+            backend = getattr(self.tokenizer, "backend_tokenizer", None)
+            target = getattr(backend, "_tokenizer", backend)
+            norm = getattr(target, "normalizer", None)
+            dec = getattr(target, "decoder", None)
+            sn = repr(norm)[:800] if norm is not None else ""
+            sd = repr(dec)[:800] if dec is not None else ""
+            has_prepend = "Prepend" in sn
+            has_strip = "Strip" in sd
+            if has_prepend and has_strip:
+                target.decoder = _dec.Sequence([
+                    _dec.Replace("▁", " "),
+                    _dec.ByteFallback(),
+                    _dec.Fuse(),
+                ])
+
         load_end_time = time.time()
         print(f"Time used: {load_end_time - load_start_time:.3f}s")

@@ -574,11 +627,8 @@ class JiugeForCauslLM:
             output_tokens = self.batch_infer_one_round([infer_task])
             end_time = time.time()
             steps += 1
-            output_str = (
-                self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            output_str = self.tokenizer.decode(output_tokens[0])
             output_content += output_str
             print(output_str, end="", flush=True)
             if output_tokens[0] in self.eos_token_id:
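Two details in scripts/jiuge.py are worth calling out. First, the "qwen3" model_type now takes the same Llama-style loading path as Qwen2, and dh honors an explicit head_dim from config.json (which Qwen3 configs provide) instead of always deriving it as hidden_size // num_attention_heads. Second, the q_norm/k_norm vectors are not handed over verbatim: the reshape([2, dh // 2]).transpose(0, 1).contiguous() step permutes each [dh] scale so its memory order matches how this engine appears to lay out the two rotary halves of a head. A small sketch of that permutation, illustration only with a toy dh:

import torch

dh = 8
w = torch.arange(dh, dtype=torch.float32)   # [0, 1, 2, 3, 4, 5, 6, 7]

# Same transform as above: split into the two halves, then interleave them in memory.
perm = w.reshape([2, dh // 2]).transpose(0, 1).contiguous()
print(perm.reshape(-1))                     # tensor([0., 4., 1., 5., 2., 6., 3., 7.])

The permuted tensors are also kept on the Python object as self.attn_*_norm_tensors, presumably so the ctypes pointer arrays passed to the C++ side keep pointing at live memory.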
scripts/jiuge_awq.py

@@ -256,11 +256,6 @@ class JiugeAWQForCausalLM:
             output_tokens = self.batch_infer_one_round([infer_task])
             end_time = time.time()
             steps += 1
-            # output_str = (
-            #     self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-            #     .replace("▁", " ")
-            #     .replace("<0x0A>", "\n")
-            # )
             output_str = self.tokenizer.decode(output_tokens[0])
             output_content += output_str
             print(output_str, end="", flush=True)
scripts/launch_server.py

@@ -226,11 +226,8 @@ async def chat_stream(id_, request_data, request: Request):
                 break
             token = await infer_task.output_queue.async_q.get()
-            content = (
-                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            content = request.app.state.model.tokenizer.decode(token)
             chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False)
             yield f"data: {chunk}\n\n"

@@ -255,11 +252,7 @@ async def chat(id_, request_data, request: Request):
                 break
             token = await infer_task.output_queue.async_q.get()
-            content = (
-                request.app.state.model.tokenizer._tokenizer.id_to_token(token)
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            content = request.app.state.model.tokenizer.decode(token)
             output.append(content)
         output_text = "".join(output).strip()
scripts/libinfinicore_infer/jiuge.py

@@ -31,6 +31,8 @@ class JiugeWeightsCStruct(Structure):
         ("attn_norm", POINTER(c_void_p)),
         ("attn_qkv", POINTER(c_void_p)),
         ("attn_qkv_b", POINTER(c_void_p)),
+        ("attn_q_norm", POINTER(c_void_p)),
+        ("attn_k_norm", POINTER(c_void_p)),
         ("attn_o", POINTER(c_void_p)),
         ("ffn_norm", POINTER(c_void_p)),
         ("ffn_gate_up", POINTER(c_void_p)),
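These two fields have to sit in the same relative position as in the C header above, since ctypes lays the struct out in declaration order. A standalone toy sketch (a hypothetical two-field struct, not the real one) of how such fields are driven from Python: either an array of per-layer device pointers, or None, which the C++ loader observes as nullptr and uses to skip the QK-norm path; cast() is used here for explicitness.

from ctypes import POINTER, Structure, c_void_p, cast

class ToyWeights(Structure):
    # hypothetical struct mirroring only the two new entries
    _fields_ = [("attn_q_norm", POINTER(c_void_p)),
                ("attn_k_norm", POINTER(c_void_p))]

nlayer = 2
fake_ptrs = [0x1000, 0x2000]  # stand-ins for tensor.data_ptr() values

w = ToyWeights()
w.attn_q_norm = cast((c_void_p * nlayer)(*fake_ptrs), POINTER(c_void_p))
w.attn_k_norm = None  # becomes a NULL pointer on the C++ side

print(bool(w.attn_q_norm), bool(w.attn_k_norm))  # True False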
scripts/test_ceval.py

@@ -43,11 +43,7 @@ class JiugeForCeval(JiugeForCauslLM):
             output_tokens = self.batch_infer_one_round([infer_task])
             end_time = time.time()
             steps += 1
-            output_str = (
-                self.tokenizer._tokenizer.id_to_token(output_tokens[0])
-                .replace("▁", " ")
-                .replace("<0x0A>", "\n")
-            )
+            output_str = self.tokenizer.decode(output_tokens[0])
             output_content += output_str
             print(output_str, end="", flush=True)
             if output_tokens[0] in self.eos_token_id:
src/models/jiuge/jiuge.cpp

@@ -21,7 +21,7 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta,
     infinirtStream_t stream;
     infinirtStreamCreate(&stream);
 
-    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
+    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
         w_ffn_norm, w_ffn_gate_up, w_ffn_down;
     for (size_t layer = 0; layer < meta->nlayer; layer++) {
         w_attn_norm.push_back(

@@ -32,6 +32,13 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta,
             b_attn_qkv.push_back(
                 getAttnQKVBias(meta, weights, layer, idev, ndev));
         }
+        if (weights->attn_q_norm != nullptr) {
+            w_attn_q_norm.push_back(
+                getAttnQNorm(meta, weights, layer));
+            w_attn_k_norm.push_back(
+                getAttnKNorm(meta, weights, layer));
+        }
+
         w_attn_out.push_back(
             getAttnO(meta, weights, layer, idev, ndev));
         w_ffn_norm.push_back(

@@ -56,6 +63,8 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta,
         w_attn_norm,
         w_attn_qkv,
         b_attn_qkv,
+        w_attn_q_norm,
+        w_attn_k_norm,
         w_attn_out,
         w_ffn_norm,
         w_ffn_gate_up,

@@ -130,6 +139,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
     auto dvoc = meta.dvoc;
     auto stream = rsrc.stream;
     bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
+    bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0;
 
     // Allocate buffers
     auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);

@@ -142,6 +152,8 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
     auto result_cpu = std::vector<int64_t>(nreq);
 
     auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
+    auto q_buf = qkv_rope->slice(1, 0, nh);
+    auto k_buf = qkv_rope->slice(1, nh, nkvh);
 
     // Prepare inputs
     auto batch_pos_ids = std::vector<uint32_t>(ntok);

@@ -198,9 +210,15 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
         rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
         // qkv_proj
         linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
+
+        if (has_qk_norm) {
+            rmsnorm(q_buf, q_buf, rsrc.w_attn_q_norm[layer], meta.epsilon);
+            rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer], meta.epsilon);
+        }
+
         // rope
-        rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
-        rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+        rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
+        rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
 
         size_t token_offset = 0;
         for (uint32_t req = 0; req < nreq; req++) {

The final two hunks show identical content on both sides of the diff view, so only formatting appears to have changed:

@@ -299,11 +317,11 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
 __C void
 inferBatchJiuge(struct JiugeModel *model,
                 const uint32_t *tokens, uint32_t ntok,
                 const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                 struct KVCache **kv_caches,
                 const float *temperature, const uint32_t *topk, const float *topp,
                 uint32_t *output) {
     model->req.tokens = tokens;
     model->req.ntok = ntok;
     model->req.req_lens = req_lens;

@@ -332,10 +350,10 @@ inferBatchJiuge(struct JiugeModel *model,
 __C void
 forwardBatchJiuge(struct JiugeModel *model,
                   const uint32_t *tokens, uint32_t ntok,
                   const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                   struct KVCache **kv_caches,
                   void *logits) {
     model->req.tokens = tokens;
     model->req.ntok = ntok;
     model->req.req_lens = req_lens;
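In short: per-layer Q/K norm weights are loaded only when the Python side supplied non-null pointers, has_qk_norm gates the extra work at inference time, and the normalization is applied to the query and key slices of the fused QKV buffer before RoPE. A minimal PyTorch sketch of the same sequence, illustration only and not the engine code, with shapes following the [ntok, nh + 2*nkvh, dh] view used above:

import torch

def rmsnorm(x, w, eps=1e-6):
    # normalize each head vector over its last (dh) dimension, then scale by w
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

ntok, nh, nkvh, dh = 4, 8, 2, 64
qkv = torch.randn(ntok, nh + 2 * nkvh, dh)   # fused QKV activations
q = qkv[:, :nh]                              # mirrors slice(1, 0, nh)
k = qkv[:, nh:nh + nkvh]                     # mirrors slice(1, nh, nkvh)

w_q_norm = torch.ones(dh)                    # stand-ins for w_attn_q_norm[layer]
w_k_norm = torch.ones(dh)

q = rmsnorm(q, w_q_norm)                     # QK-norm before rotary
k = rmsnorm(k, w_k_norm)
# RoPE would be applied to q and k next, as in the modified rope() calls.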
src/models/jiuge/jiuge_impl.hpp

@@ -20,7 +20,7 @@ struct JiugeDeviceResource {
     // Weights
     std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table,
         cos_table;
-    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
+    std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
         w_ffn_norm, w_ffn_gate_up, w_ffn_down;
     // Streams
     infinirtStream_t stream;
src/models/jiuge/jiuge_weight.hpp

@@ -70,6 +70,22 @@ inline std::shared_ptr<Tensor> getAttnQKVBias(
     return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape);
 }
 
+inline std::shared_ptr<Tensor> getAttnQNorm(JiugeMeta const *meta,
+                                            JiugeWeights const *w,
+                                            size_t layer) {
+    auto shape = std::vector<size_t>({meta->dh});
+    return Tensor::weight((char *)(w->attn_q_norm[layer]), w->dt_norm, shape);
+}
+
+inline std::shared_ptr<Tensor> getAttnKNorm(JiugeMeta const *meta,
+                                            JiugeWeights const *w,
+                                            size_t layer) {
+    auto shape = std::vector<size_t>({meta->dh});
+    return Tensor::weight((char *)(w->attn_k_norm[layer]), w->dt_norm, shape);
+}
+
 inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
                                         JiugeWeights const *w, size_t layer,
                                         size_t idev, size_t ndev) {