Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
366d3aef
"vscode:/vscode.git/clone" did not exist on "91c3669f21cf569f13bd15569b5411e6ca9a8961"
Commit
366d3aef
authored
May 22, 2025
by
Pan Zezhong
Browse files
support 9g7b
parent
c3d5efa5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
95 additions
and
42 deletions
+95
-42
scripts/jiuge.py
scripts/jiuge.py
+92
-39
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+2
-2
src/models/jiuge/jiuge_weight.hpp
src/models/jiuge/jiuge_weight.hpp
+1
-1
No files found.
scripts/jiuge.py
View file @
366d3aef
from
ctypes
import
POINTER
,
c_int
,
c_uint
,
c_void_p
,
byref
from
ctypes
import
POINTER
,
c_int
,
c_uint
,
c_void_p
,
byref
import
os
from
pathlib
import
Path
from
pathlib
import
Path
import
safetensors
import
safetensors
import
sys
import
sys
import
time
import
time
import
json
from
libinfinicore_infer
import
(
from
libinfinicore_infer
import
(
JiugeMeta
,
JiugeMeta
,
...
@@ -18,6 +20,7 @@ from libinfinicore_infer import (
...
@@ -18,6 +20,7 @@ from libinfinicore_infer import (
import
torch
import
torch
import
transformers
import
transformers
torch
.
set_default_device
(
"cpu"
)
class
LlamaWeightsNaming
:
class
LlamaWeightsNaming
:
def
input_embd
(
self
):
def
input_embd
(
self
):
...
@@ -73,8 +76,8 @@ class LlamaWeightsNaming:
...
@@ -73,8 +76,8 @@ class LlamaWeightsNaming:
class
JiugeMetaFromLlama
(
JiugeMeta
):
class
JiugeMetaFromLlama
(
JiugeMeta
):
def
__init__
(
self
,
config
,
dtype
=
torch
.
float16
):
def
__init__
(
self
,
config
,
dtype
=
torch
.
float16
):
if
dtype
==
torch
.
float16
:
if
dtype
==
torch
.
float16
:
dt_
=
DataType
.
INFINI_DTYPE_F16
dt_
=
DataType
.
INFINI_DTYPE_F16
elif
dtype
==
torch
.
float32
:
elif
dtype
==
torch
.
float32
:
dt_
=
DataType
.
INFINI_DTYPE_F32
dt_
=
DataType
.
INFINI_DTYPE_F32
...
@@ -82,27 +85,35 @@ class JiugeMetaFromLlama(JiugeMeta):
...
@@ -82,27 +85,35 @@ class JiugeMetaFromLlama(JiugeMeta):
dt_
=
DataType
.
INFINI_DTYPE_F16
dt_
=
DataType
.
INFINI_DTYPE_F16
super
().
__init__
(
super
().
__init__
(
dt_logits
=
dt_
,
dt_logits
=
dt_
,
nlayer
=
config
.
num_hidden_layers
,
nlayer
=
config
[
"
num_hidden_layers
"
]
,
d
=
config
.
hidden_size
,
d
=
config
[
"
hidden_size
"
]
,
nh
=
config
.
num_attention_heads
,
nh
=
config
[
"
num_attention_heads
"
]
,
nkvh
=
(
nkvh
=
(
config
.
num_key_value_heads
config
[
"
num_key_value_heads
"
]
if
config
.
num_key_value_heads
if
"
num_key_value_heads
"
in
config
else
config
.
num_attention_heads
else
config
[
"
num_attention_heads
"
]
),
),
dh
=
config
.
hidden_size
//
config
.
num_attention_heads
,
dh
=
config
[
"
hidden_size
"
]
//
config
[
"
num_attention_heads
"
]
,
di
=
config
.
intermediate_size
,
di
=
config
[
"
intermediate_size
"
]
,
dctx
=
config
.
max_position_embeddings
,
dctx
=
config
[
"
max_position_embeddings
"
]
,
dvoc
=
config
.
vocab_size
,
dvoc
=
config
[
"
vocab_size
"
]
,
epsilon
=
config
.
rms_norm_eps
,
epsilon
=
config
[
"
rms_norm_eps
"
]
,
theta
=
config
.
rope_theta
,
theta
=
(
config
[
"
rope_theta
"
]
if
"rope_theta"
in
config
else
100000.0
)
,
end_token
=
2
,
end_token
=
2
,
)
)
self
.
torch_dtype_logits
=
dtype
self
.
torch_dtype_logits
=
dtype
class
JiugeWeightsImpl
(
JiugeWeights
):
class
JiugeWeightsImpl
(
JiugeWeights
):
def
__init__
(
self
,
meta
,
naming
,
state_dict
,
torch_dt_mat
=
torch
.
float16
,
torch_dt_norm
=
torch
.
float32
,
ndev
=
1
):
def
__init__
(
self
,
meta
,
naming
,
state_dict
,
torch_dt_mat
=
torch
.
float16
,
torch_dt_norm
=
torch
.
float32
,
ndev
=
1
,
):
nlayer
=
meta
.
nlayer
nlayer
=
meta
.
nlayer
nh
=
meta
.
nh
nh
=
meta
.
nh
nkvh
=
meta
.
nkvh
nkvh
=
meta
.
nkvh
...
@@ -127,12 +138,23 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -127,12 +138,23 @@ class JiugeWeightsImpl(JiugeWeights):
else
:
else
:
raise
ValueError
(
"Unsupported norm weight data type"
)
raise
ValueError
(
"Unsupported norm weight data type"
)
input_embd_naming
=
(
naming
.
input_embd
()
if
naming
.
input_embd
()
in
state_dict
else
naming
.
output_embd
()
)
output_embd_naming
=
(
naming
.
output_embd
()
if
naming
.
output_embd
()
in
state_dict
else
naming
.
input_embd
()
)
self
.
nlayer
=
nlayer
self
.
nlayer
=
nlayer
self
.
input_embd_tensor
=
state_dict
[
naming
.
input_embd
()
].
to
(
torch_dt_logits
)
self
.
input_embd_tensor
=
state_dict
[
input_embd
_naming
].
to
(
torch_dt_logits
)
self
.
input_embd
=
self
.
input_embd_tensor
.
data_ptr
()
self
.
input_embd
=
self
.
input_embd_tensor
.
data_ptr
()
self
.
output_norm_tensor
=
state_dict
[
naming
.
output_norm
()].
to
(
torch_dt_norm
)
self
.
output_norm_tensor
=
state_dict
[
naming
.
output_norm
()].
to
(
torch_dt_norm
)
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_embd_tensor
=
state_dict
[
naming
.
output_embd
()
].
to
(
torch_dt_mat
)
self
.
output_embd_tensor
=
state_dict
[
output_embd
_naming
].
to
(
torch_dt_mat
)
self
.
output_embd
=
self
.
output_embd_tensor
.
data_ptr
()
self
.
output_embd
=
self
.
output_embd_tensor
.
data_ptr
()
self
.
attn_norm_tensors
=
[
self
.
attn_norm_tensors
=
[
...
@@ -164,7 +186,9 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -164,7 +186,9 @@ class JiugeWeightsImpl(JiugeWeights):
_result
.
append
(
_V
[
_idev
*
_nkvh
:
(
_idev
+
1
)
*
_nkvh
,
:,
:])
_result
.
append
(
_V
[
_idev
*
_nkvh
:
(
_idev
+
1
)
*
_nkvh
,
:,
:])
return
_result
return
_result
self
.
qkv_tensor
=
[
torch
.
concat
(
qkv_slices
(
i
)).
to
(
torch_dt_mat
)
for
i
in
range
(
nlayer
)]
self
.
qkv_tensor
=
[
torch
.
concat
(
qkv_slices
(
i
)).
to
(
torch_dt_mat
)
for
i
in
range
(
nlayer
)
]
self
.
qkv_tensor_ptrs
=
[
self
.
qkv_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
qkv_tensor_ptrs
=
[
self
.
qkv_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_qkv
=
(
c_void_p
*
nlayer
)(
*
self
.
qkv_tensor_ptrs
)
self
.
attn_qkv
=
(
c_void_p
*
nlayer
)(
*
self
.
qkv_tensor_ptrs
)
...
@@ -184,13 +208,15 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -184,13 +208,15 @@ class JiugeWeightsImpl(JiugeWeights):
_nh
=
nh
//
ndev
_nh
=
nh
//
ndev
_nkvh
=
nkvh
//
ndev
_nkvh
=
nkvh
//
ndev
for
_idev
in
range
(
ndev
):
for
_idev
in
range
(
ndev
):
_result
.
append
(
_QB
[
_idev
*
_nh
:
(
_idev
+
1
)
*
_nh
,
:,
:])
_result
.
append
(
_QB
[
_idev
*
_nh
:
(
_idev
+
1
)
*
_nh
,
:,
:]
.
flatten
()
)
_result
.
append
(
_KB
[
_idev
*
_nkvh
:
(
_idev
+
1
)
*
_nkvh
,
:,
:])
_result
.
append
(
_KB
[
_idev
*
_nkvh
:
(
_idev
+
1
)
*
_nkvh
,
:,
:]
.
flatten
()
)
_result
.
append
(
_VB
[
_idev
*
_nkvh
:
(
_idev
+
1
)
*
_nkvh
,
:,
:])
_result
.
append
(
_VB
[
_idev
*
_nkvh
:
(
_idev
+
1
)
*
_nkvh
,
:,
:]
.
flatten
()
)
return
_result
return
_result
if
naming
.
attn_q_b
(
0
)
in
state_dict
:
if
naming
.
attn_q_b
(
0
)
in
state_dict
:
self
.
qkv_b_tensors
=
[
torch
.
concat
(
qkv_b_slices
(
i
)).
to
(
torch_dt_logits
)
for
i
in
range
(
nlayer
)]
self
.
qkv_b_tensors
=
[
torch
.
concat
(
qkv_b_slices
(
i
)).
to
(
torch_dt_logits
)
for
i
in
range
(
nlayer
)
]
self
.
qkv_b_tensor_ptrs
=
[
self
.
qkv_b_tensor_ptrs
=
[
self
.
qkv_b_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)
self
.
qkv_b_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)
]
]
...
@@ -199,7 +225,8 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -199,7 +225,8 @@ class JiugeWeightsImpl(JiugeWeights):
self
.
attn_qkv_b
=
None
self
.
attn_qkv_b
=
None
self
.
attn_o_tensor
=
[
self
.
attn_o_tensor
=
[
state_dict
[
naming
.
attn_o
(
i
)].
to
(
torch_dt_mat
)
state_dict
[
naming
.
attn_o
(
i
)]
.
to
(
torch_dt_mat
)
.
reshape
([
d
,
ndev
,
nh
//
ndev
*
dh
])
.
reshape
([
d
,
ndev
,
nh
//
ndev
*
dh
])
.
transpose
(
0
,
1
)
.
transpose
(
0
,
1
)
.
contiguous
()
.
contiguous
()
...
@@ -208,7 +235,9 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -208,7 +235,9 @@ class JiugeWeightsImpl(JiugeWeights):
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_o
=
(
c_void_p
*
nlayer
)(
*
self
.
attn_o_ptrs
)
self
.
attn_o
=
(
c_void_p
*
nlayer
)(
*
self
.
attn_o_ptrs
)
self
.
ffn_norm_tensors
=
[
state_dict
[
naming
.
ffn_norm
(
i
)].
to
(
torch_dt_norm
)
for
i
in
range
(
nlayer
)]
self
.
ffn_norm_tensors
=
[
state_dict
[
naming
.
ffn_norm
(
i
)].
to
(
torch_dt_norm
)
for
i
in
range
(
nlayer
)
]
self
.
ffn_norm_ptrs
=
[
self
.
ffn_norm_ptrs
=
[
self
.
ffn_norm_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)
self
.
ffn_norm_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)
]
]
...
@@ -224,12 +253,15 @@ class JiugeWeightsImpl(JiugeWeights):
...
@@ -224,12 +253,15 @@ class JiugeWeightsImpl(JiugeWeights):
_result
.
append
(
state_dict
[
naming
.
up
(
_i
)][
_start
:
_end
,
:])
_result
.
append
(
state_dict
[
naming
.
up
(
_i
)][
_start
:
_end
,
:])
return
_result
return
_result
self
.
gate_up_tensors
=
[
torch
.
concat
(
gate_up_slices
(
i
)).
to
(
torch_dt_mat
)
for
i
in
range
(
nlayer
)]
self
.
gate_up_tensors
=
[
torch
.
concat
(
gate_up_slices
(
i
)).
to
(
torch_dt_mat
)
for
i
in
range
(
nlayer
)
]
self
.
gate_up_ptrs
=
[
self
.
gate_up_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
gate_up_ptrs
=
[
self
.
gate_up_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
ffn_gate_up
=
(
c_void_p
*
nlayer
)(
*
self
.
gate_up_ptrs
)
self
.
ffn_gate_up
=
(
c_void_p
*
nlayer
)(
*
self
.
gate_up_ptrs
)
self
.
ffn_down_tensor
=
[
self
.
ffn_down_tensor
=
[
state_dict
[
naming
.
down
(
i
)].
to
(
torch_dt_mat
)
state_dict
[
naming
.
down
(
i
)]
.
to
(
torch_dt_mat
)
.
reshape
([
d
,
ndev
,
di
//
ndev
])
.
reshape
([
d
,
ndev
,
di
//
ndev
])
.
transpose
(
0
,
1
)
.
transpose
(
0
,
1
)
.
contiguous
()
.
contiguous
()
...
@@ -250,17 +282,17 @@ class JiugeForCauslLM:
...
@@ -250,17 +282,17 @@ class JiugeForCauslLM:
tensors_
[
name_
]
=
data_
.
get_tensor
(
name_
)
tensors_
[
name_
]
=
data_
.
get_tensor
(
name_
)
return
tensors_
return
tensors_
config
=
transformers
.
AutoConfig
.
from_pretrained
(
with
open
(
os
.
path
.
join
(
model_dir_path
,
"config.json"
),
"r"
)
as
f
:
model_dir_path
,
trust_remote_code
=
True
config
=
json
.
load
(
f
)
)
if
"llama"
==
config
.
model_type
:
if
"llama"
==
config
[
"
model_type
"
]
:
model
=
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
).
half
()
model
=
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
).
cpu
().
half
()
self
.
meta
=
JiugeMetaFromLlama
(
model
.
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
)
)
elif
"fm9g"
==
config
.
model_type
:
elif
"fm9g"
==
config
[
"
model_type
"
]
:
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
meta
=
JiugeMetaFromLlama
(
config
)
...
@@ -270,6 +302,19 @@ class JiugeForCauslLM:
...
@@ -270,6 +302,19 @@ class JiugeForCauslLM:
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
model_dir_path
,
trust_remote_code
=
True
)
)
elif
"fm9g7b"
==
config
[
"model_type"
]:
state_dict
=
torch
.
load
(
os
.
path
.
join
(
model_dir_path
,
"pytorch_model.bin"
),
weights_only
=
True
,
map_location
=
"cpu"
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
)
else
:
else
:
raise
ValueError
(
"Unsupported model architecture"
)
raise
ValueError
(
"Unsupported model architecture"
)
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
...
@@ -285,6 +330,11 @@ class JiugeForCauslLM:
...
@@ -285,6 +330,11 @@ class JiugeForCauslLM:
pass
pass
def
generate
(
self
,
input_content
,
max_steps
,
topp
=
1.0
,
topk
=
1
,
temperature
=
1.0
):
def
generate
(
self
,
input_content
,
max_steps
,
topp
=
1.0
,
topk
=
1
,
temperature
=
1.0
):
input_content
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
[{
"role"
:
"user"
,
"content"
:
input_content
}],
add_generation_prompt
=
True
,
tokenize
=
False
,
)
print
(
input_content
,
end
=
""
,
flush
=
True
)
print
(
input_content
,
end
=
""
,
flush
=
True
)
kv_cache
=
create_kv_cache
(
self
.
model_instance
)
kv_cache
=
create_kv_cache
(
self
.
model_instance
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
tokens
=
self
.
tokenizer
.
encode
(
input_content
)
...
@@ -298,8 +348,10 @@ class JiugeForCauslLM:
...
@@ -298,8 +348,10 @@ class JiugeForCauslLM:
ans
=
(
c_uint
*
nreq
)()
ans
=
(
c_uint
*
nreq
)()
steps
=
0
steps
=
0
start_time
=
time
.
time
()
total_time
=
0
for
_
in
range
(
max_steps
):
for
step_i
in
range
(
max_steps
):
start_time
=
time
.
time
()
infer_batch
(
infer_batch
(
self
.
model_instance
,
self
.
model_instance
,
tokens
,
tokens
,
...
@@ -324,15 +376,16 @@ class JiugeForCauslLM:
...
@@ -324,15 +376,16 @@ class JiugeForCauslLM:
break
break
output_content
+=
output_str
output_content
+=
output_str
print
(
output_str
,
end
=
""
,
flush
=
True
)
print
(
output_str
,
end
=
""
,
flush
=
True
)
# print(output_tokens[0])
req_pos
[
0
]
=
req_pos
[
0
]
+
ntok
req_pos
[
0
]
=
req_pos
[
0
]
+
ntok
ntok
=
1
ntok
=
1
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
tokens
=
(
c_uint
*
ntok
)(
*
output_tokens
)
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
req_lens
=
(
c_uint
*
nreq
)(
*
[
ntok
])
end_time
=
time
.
time
()
if
step_i
>
0
:
total_time
+=
end_time
-
start_time
print
(
"
\n
"
)
print
(
"
\n
"
)
end_time
=
time
.
time
()
avg_time
=
total_time
*
1000
/
(
steps
-
1
)
avg_time
=
(
end_time
-
start_time
)
*
1000
/
steps
print
(
f
"Time per step:
{
avg_time
:.
3
f
}
ms"
)
print
(
f
"Time per step:
{
avg_time
:.
3
f
}
ms"
)
for
kv_cache
in
kv_caches
:
for
kv_cache
in
kv_caches
:
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
drop_kv_cache
(
self
.
model_instance
,
kv_cache
)
...
@@ -367,7 +420,7 @@ def test():
...
@@ -367,7 +420,7 @@ def test():
ndev
=
int
(
sys
.
argv
[
3
])
if
len
(
sys
.
argv
)
>
3
else
1
ndev
=
int
(
sys
.
argv
[
3
])
if
len
(
sys
.
argv
)
>
3
else
1
model
=
JiugeForCauslLM
(
model_path
,
device_type
,
ndev
)
model
=
JiugeForCauslLM
(
model_path
,
device_type
,
ndev
)
model
.
generate
(
"
Once upon a time,
"
,
1
00
)
model
.
generate
(
"
山东最高的山是?
"
,
5
00
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
src/models/jiuge/jiuge.cpp
View file @
366d3aef
...
@@ -243,12 +243,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -243,12 +243,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
if
(
has_qkv_bias
)
{
if
(
has_qkv_bias
)
{
RUN_INFINI
(
infiniopRearrange
(
RUN_INFINI
(
infiniopRearrange
(
desc_qkv_bias
,
desc_qkv_bias
,
qkv_buf
->
data
(),
rsrc
.
b_attn_qkv
.
data
(),
stream
));
qkv_buf
->
data
(),
rsrc
.
b_attn_qkv
[
layer
]
->
data
(),
stream
));
}
}
RUN_INFINI
(
infiniopGemm
(
RUN_INFINI
(
infiniopGemm
(
desc_attn_qkv
,
workspace
,
workspace_size
,
desc_attn_qkv
,
workspace
,
workspace_size
,
qkv_buf
->
data
(),
logits_out
->
data
(),
qkv_buf
->
data
(),
logits_out
->
data
(),
rsrc
.
w_attn_qkv
[
layer
]
->
data
(),
1.0
,
0.0
,
stream
));
rsrc
.
w_attn_qkv
[
layer
]
->
data
(),
1.0
,
has_qkv_bias
?
1.0
:
0.0
,
stream
));
// rope
// rope
RUN_INFINI
(
infiniopRoPE
(
RUN_INFINI
(
infiniopRoPE
(
desc_rope_q
,
workspace
,
workspace_size
,
desc_rope_q
,
workspace
,
workspace_size
,
...
...
src/models/jiuge/jiuge_weight.hpp
View file @
366d3aef
...
@@ -56,7 +56,7 @@ inline std::shared_ptr<Tensor> getAttnQKVBias(
...
@@ -56,7 +56,7 @@ inline std::shared_ptr<Tensor> getAttnQKVBias(
auto
nh
=
meta
->
nh
;
auto
nh
=
meta
->
nh
;
auto
dh
=
meta
->
dh
;
auto
dh
=
meta
->
dh
;
size_t
offset
=
idev
*
((
nkvh
*
2
+
nh
)
/
ndev
*
dh
)
*
dsize
(
w
->
dt_mat
);
size_t
offset
=
idev
*
((
nkvh
*
2
+
nh
)
/
ndev
*
dh
)
*
dsize
(
w
->
dt_mat
);
auto
shape
=
std
::
vector
<
size_t
>
({
1
,
(
nh
+
2
*
nkvh
)
/
ndev
*
dh
});
auto
shape
=
std
::
vector
<
size_t
>
({(
nh
+
2
*
nkvh
)
/
ndev
*
dh
});
return
Tensor
::
weight
((
char
*
)(
w
->
attn_qkv_b
[
layer
])
+
offset
,
w
->
dt_mat
,
shape
);
return
Tensor
::
weight
((
char
*
)(
w
->
attn_qkv_b
[
layer
])
+
offset
,
w
->
dt_mat
,
shape
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment