Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
967bcb64
Commit
967bcb64
authored
May 19, 2025
by
PanZezhong
Browse files
llama cpu
parent
760b769e
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
227 additions
and
124 deletions
+227
-124
scripts/jiuge.py
scripts/jiuge.py
+109
-33
scripts/libinfinicore_infer.py
scripts/libinfinicore_infer.py
+10
-10
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+20
-15
src/models/jiuge/jiuge_impl.hpp
src/models/jiuge/jiuge_impl.hpp
+1
-1
src/models/jiuge/jiuge_weight.hpp
src/models/jiuge/jiuge_weight.hpp
+8
-8
src/tensor.hpp
src/tensor.hpp
+2
-3
src/tensor/tensor.cpp
src/tensor/tensor.cpp
+47
-49
src/tensor/transform.cpp
src/tensor/transform.cpp
+3
-5
src/utils.hpp
src/utils.hpp
+27
-0
No files found.
scripts/jiuge.py
View file @
967bcb64
from
ctypes
import
POINTER
,
c_uint
,
c_void_p
,
byref
from
ctypes
import
POINTER
,
c_int
,
c_uint
,
c_void_p
,
byref
from
pathlib
import
Path
import
safetensors
import
sys
import
time
from
libinfinicore_infer
import
(
JiugeMeta
,
JiugeWeights
,
...
...
@@ -50,6 +53,9 @@ class LlamaWeightsNaming:
def
attn_v_b
(
self
,
i
):
return
f
"model.layers.
{
i
}
.self_attn.v_proj.bias"
def
ffn_norm
(
self
,
i
):
return
f
"model.layers.
{
i
}
.post_attention_layernorm.weight"
def
gate
(
self
,
i
):
return
f
"model.layers.
{
i
}
.mlp.gate_proj.weight"
...
...
@@ -59,6 +65,12 @@ class LlamaWeightsNaming:
def
down
(
self
,
i
):
return
f
"model.layers.
{
i
}
.mlp.down_proj.weight"
def
match
(
state_dict
):
return
(
"model.norm.weight"
in
state_dict
and
"model.layers.0.self_attn.q_proj.weight"
in
state_dict
)
class
JiugeMetaFromLlama
(
JiugeMeta
):
def
__init__
(
self
,
config
,
infini_dtype
):
...
...
@@ -96,12 +108,22 @@ class JiugeWeightsImpl(JiugeWeights):
assert
nh
%
ndev
==
0
assert
nkvh
%
ndev
==
0
assert
di
%
ndev
==
0
self
.
input_embd
=
state_dict
[
naming
.
input_embd
()].
data_ptr
()
self
.
output_norm
=
state_dict
[
naming
.
output_norm
()].
data_ptr
()
self
.
output_embd
=
state_dict
[
naming
.
output_embd
()].
data_ptr
()
self
.
attn_norm
=
(
c_void_p
*
nlayer
)(
*
[
state_dict
[
naming
.
attn_norm
(
i
)].
data_ptr
()
for
i
in
range
(
nlayer
)]
)
self
.
nlayer
=
nlayer
self
.
input_embd_tensor
=
state_dict
[
naming
.
input_embd
()]
self
.
input_embd
=
self
.
input_embd_tensor
.
data_ptr
()
self
.
output_norm_tensor
=
state_dict
[
naming
.
output_norm
()]
self
.
output_norm
=
self
.
output_norm_tensor
.
data_ptr
()
self
.
output_embd_tensor
=
state_dict
[
naming
.
output_embd
()]
self
.
output_embd
=
self
.
output_embd_tensor
.
data_ptr
()
self
.
attn_norm_tensors
=
[
state_dict
[
naming
.
attn_norm
(
i
)]
for
i
in
range
(
nlayer
)
]
self
.
attn_norm_ptrs
=
[
self
.
attn_norm_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)
]
self
.
attn_norm
=
(
c_void_p
*
nlayer
)(
*
self
.
attn_norm_ptrs
)
def
qkv_slices
(
_i
):
_Q
=
(
...
...
@@ -125,9 +147,39 @@ class JiugeWeightsImpl(JiugeWeights):
return
_result
self
.
qkv_tensor
=
[
torch
.
concat
(
qkv_slices
(
i
))
for
i
in
range
(
nlayer
)]
self
.
attn_qkv
=
(
c_void_p
*
nlayer
)(
*
[
self
.
qkv_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
)
self
.
qkv_tensor_ptrs
=
[
self
.
qkv_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_qkv
=
(
c_void_p
*
nlayer
)(
*
self
.
qkv_tensor_ptrs
)
def
qkv_b_slices
(
_i
):
_QB
=
(
state_dict
[
naming
.
attn_q_b
(
_i
)]
.
reshape
([
nh
,
2
,
dh
//
2
])
.
transpose
(
1
,
2
)
)
_KB
=
(
state_dict
[
naming
.
attn_k_b
(
_i
)]
.
reshape
([
nkvh
,
2
,
dh
//
2
])
.
transpose
(
1
,
2
)
)
_VB
=
state_dict
[
naming
.
attn_v_b
(
_i
)].
reshape
([
nkvh
,
dh
//
2
,
2
])
_result
=
[]
_nh
=
nh
//
ndev
_nkvh
=
nkvh
//
ndev
for
_idev
in
range
(
ndev
):
_result
.
append
(
_QB
[
_idev
*
_nh
:
(
_idev
+
1
)
*
_nh
,
:,
:])
_result
.
append
(
_KB
[
_idev
*
_nkvh
:
(
_idev
+
1
)
*
_nkvh
,
:,
:])
_result
.
append
(
_VB
[
_idev
*
_nkvh
:
(
_idev
+
1
)
*
_nkvh
,
:,
:])
return
_result
if
naming
.
attn_q_b
(
0
)
in
state_dict
:
self
.
qkv_b_tensors
=
[
torch
.
concat
(
qkv_b_slices
(
i
))
for
i
in
range
(
nlayer
)]
self
.
qkv_b_tensor_ptrs
=
[
self
.
qkv_b_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)
]
self
.
attn_qkv_b
=
(
c_void_p
*
nlayer
)(
*
self
.
qkv_b_tensor_ptrs
)
else
:
self
.
attn_qkv_b
=
None
self
.
attn_o_tensor
=
[
state_dict
[
naming
.
attn_o
(
i
)]
.
reshape
([
d
,
ndev
,
nh
//
ndev
*
dh
])
...
...
@@ -135,12 +187,14 @@ class JiugeWeightsImpl(JiugeWeights):
.
contiguous
()
for
i
in
range
(
nlayer
)
]
self
.
attn_o
=
(
c_void_p
*
nlayer
)(
*
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
)
self
.
ffn_norm
=
(
c_void_p
*
nlayer
)(
*
[
state_dict
[
naming
.
ffn_norm
(
i
)].
data_ptr
()
for
i
in
range
(
nlayer
)]
)
self
.
attn_o_ptrs
=
[
self
.
attn_o_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
attn_o
=
(
c_void_p
*
nlayer
)(
*
self
.
attn_o_ptrs
)
self
.
ffn_norm_tensors
=
[
state_dict
[
naming
.
ffn_norm
(
i
)]
for
i
in
range
(
nlayer
)]
self
.
ffn_norm_ptrs
=
[
self
.
ffn_norm_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)
]
self
.
ffn_norm
=
(
c_void_p
*
nlayer
)(
*
self
.
ffn_norm_ptrs
)
def
gate_up_slices
(
_i
):
_result
=
[]
...
...
@@ -152,11 +206,9 @@ class JiugeWeightsImpl(JiugeWeights):
_result
.
append
(
state_dict
[
naming
.
up
(
_i
)][
_start
:
_end
,
:])
return
_result
self
.
gate_up_tensor
=
[
torch
.
concat
(
gate_up_slices
(
i
))
for
i
in
range
(
nlayer
)]
self
.
ffn_gate_up
=
(
c_void_p
*
nlayer
)(
*
[
self
.
gate_up_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
)
self
.
gate_up_tensors
=
[
torch
.
concat
(
gate_up_slices
(
i
))
for
i
in
range
(
nlayer
)]
self
.
gate_up_ptrs
=
[
self
.
gate_up_tensors
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
ffn_gate_up
=
(
c_void_p
*
nlayer
)(
*
self
.
gate_up_ptrs
)
self
.
ffn_down_tensor
=
[
state_dict
[
naming
.
down
(
i
)]
...
...
@@ -165,22 +217,46 @@ class JiugeWeightsImpl(JiugeWeights):
.
contiguous
()
for
i
in
range
(
nlayer
)
]
self
.
ffn_down
=
(
c_void_p
*
nlayer
)(
*
[
self
.
ffn_down_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
)
self
.
ffn_down_ptrs
=
[
self
.
ffn_down_tensor
[
i
].
data_ptr
()
for
i
in
range
(
nlayer
)]
self
.
ffn_down
=
(
c_void_p
*
nlayer
)(
*
self
.
ffn_down_ptrs
)
class
JiugeForCauslLM
:
def
__init__
(
self
,
model_dir_path
,
device
=
DeviceType
.
DEVICE_TYPE_CPU
,
ndev
=
1
):
model
=
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
,
torch_dtype
=
torch
.
float16
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
self
.
meta
=
JiugeMetaFromLlama
(
model
.
config
,
DataType
.
INFINI_DTYPE_F16
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
def
load_all_safetensors_from_dir
(
dir_path_
:
str
,
torch_type
=
torch
.
float16
):
tensors_
=
{}
dir_path_
=
Path
(
dir_path_
)
for
file
in
sorted
(
dir_path_
.
glob
(
"*.safetensors"
)):
data_
=
safetensors
.
safe_open
(
file
,
"pt"
)
for
name_
in
data_
.
keys
():
tensors_
[
name_
]
=
data_
.
get_tensor
(
name_
).
to
(
torch_type
)
return
tensors_
config
=
transformers
.
AutoConfig
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
)
dev_ids
=
(
c_uint
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
if
"llama"
==
config
.
model_type
:
model
=
transformers
.
LlamaForCausalLM
.
from_pretrained
(
model_dir_path
).
to
(
torch
.
float16
)
self
.
meta
=
JiugeMetaFromLlama
(
model
.
config
,
DataType
.
INFINI_DTYPE_F16
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
model
.
state_dict
(),
ndev
=
ndev
)
elif
"fm9g"
==
config
.
model_type
:
state_dict
=
load_all_safetensors_from_dir
(
model_dir_path
)
if
LlamaWeightsNaming
.
match
(
state_dict
):
self
.
meta
=
JiugeMetaFromLlama
(
config
,
DataType
.
INFINI_DTYPE_F16
)
self
.
weights
=
JiugeWeightsImpl
(
self
.
meta
,
LlamaWeightsNaming
(),
state_dict
,
ndev
=
ndev
)
self
.
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
model_dir_path
,
trust_remote_code
=
True
)
else
:
raise
ValueError
(
"Unsupported model architecture"
)
dev_ids
=
(
c_int
*
ndev
)(
*
[
i
for
i
in
range
(
ndev
)])
self
.
model_instance
=
create_jiuge_model
(
byref
(
self
.
meta
),
byref
(
self
.
weights
),
...
...
@@ -274,7 +350,7 @@ def test():
ndev
=
int
(
sys
.
argv
[
3
])
if
len
(
sys
.
argv
)
>
3
else
1
model
=
JiugeForCauslLM
(
model_path
,
device_type
,
ndev
)
model
.
generate
(
"
<用户>讲个长故事<AI>
"
,
5
00
)
model
.
generate
(
"
Once upon a time,
"
,
1
00
)
if
__name__
==
"__main__"
:
...
...
scripts/libinfinicore_infer.py
View file @
967bcb64
import
ctypes
from
ctypes
import
c_uint
,
c_int
,
c_float
,
c_void_p
,
POINTER
from
ctypes
import
c_size_t
,
c_uint
,
c_int
,
c_float
,
c_void_p
,
POINTER
import
os
...
...
@@ -40,14 +40,14 @@ class JiugeMeta(ctypes.Structure):
(
"dt_logits"
,
DataType
),
(
"dt_norm"
,
DataType
),
(
"dt_mat"
,
DataType
),
(
"nlayer"
,
c_
uin
t
),
(
"d"
,
c_
uin
t
),
(
"nh"
,
c_
uin
t
),
(
"nkvh"
,
c_
uin
t
),
(
"dh"
,
c_
uin
t
),
(
"di"
,
c_
uin
t
),
(
"dctx"
,
c_
uin
t
),
(
"dvoc"
,
c_
uin
t
),
(
"nlayer"
,
c_
size_
t
),
(
"d"
,
c_
size_
t
),
(
"nh"
,
c_
size_
t
),
(
"nkvh"
,
c_
size_
t
),
(
"dh"
,
c_
size_
t
),
(
"di"
,
c_
size_
t
),
(
"dctx"
,
c_
size_
t
),
(
"dvoc"
,
c_
size_
t
),
(
"epsilon"
,
c_float
),
(
"theta"
,
c_float
),
(
"end_token"
,
c_uint
),
...
...
@@ -57,7 +57,7 @@ class JiugeMeta(ctypes.Structure):
# Define the JiugeWeights struct
class
JiugeWeights
(
ctypes
.
Structure
):
_fields_
=
[
(
"nlayer"
,
c_
uin
t
),
(
"nlayer"
,
c_
size_
t
),
(
"input_embd"
,
c_void_p
),
(
"output_norm"
,
c_void_p
),
(
"output_embd"
,
c_void_p
),
...
...
src/models/jiuge/jiuge.cpp
View file @
967bcb64
...
...
@@ -188,12 +188,12 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
auto
gate_buf
=
gate_up_buf
->
slice
(
1
,
0
,
di
);
auto
up_buf
=
gate_up_buf
->
slice
(
1
,
di
,
di
);
RUN_INFINI
(
infiniopCreateSwiGLUDescriptor
(
rsrc
.
handle
,
&
desc_swiglu
,
logits_out
->
desc
()
->
get
(),
up_buf
->
desc
()
->
get
(),
gate_buf
->
desc
()
->
get
()));
rsrc
.
handle
,
&
desc_swiglu
,
gate_buf
->
desc
()
->
get
(),
up_buf
->
desc
()
->
get
(),
gate_buf
->
desc
()
->
get
()));
RUN_INFINI
(
infiniopGetSwiGLUWorkspaceSize
(
desc_swiglu
,
&
temp_size
));
workspace_size
=
std
::
max
(
workspace_size
,
temp_size
);
RUN_INFINI
(
infiniopCreateGemmDescriptor
(
rsrc
.
handle
,
&
desc_ffn_down
,
logits_in
->
desc
()
->
get
(),
logits_out
->
desc
()
->
get
(),
rsrc
.
w_ffn_down
[
0
]
->
desc
()
->
get
()));
gate_buf
->
desc
()
->
get
(),
rsrc
.
w_ffn_down
[
0
]
->
desc
()
->
get
()));
RUN_INFINI
(
infiniopGetGemmWorkspaceSize
(
desc_ffn_down
,
&
temp_size
));
workspace_size
=
std
::
max
(
workspace_size
,
temp_size
);
...
...
@@ -215,7 +215,7 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
infiniopRandomSampleDescriptor_t
desc_sample
;
RUN_INFINI
(
infiniopCreateRandomSampleDescriptor
(
rsrc
.
handle
,
&
desc_sample
,
TensorDesc
::
create
(
INFINI_DTYPE_U64
,
{
1
},
{
1
})
->
get
(),
TensorDesc
::
create
(
INFINI_DTYPE_U64
,
{},
{})
->
get
(),
TensorDesc
::
create
(
dt_logits
,
{
dvoc
},
{
1
})
->
get
()));
RUN_INFINI
(
infiniopGetRandomSampleWorkspaceSize
(
desc_sample
,
&
temp_size
));
workspace_size
=
std
::
max
(
workspace_size
,
temp_size
);
...
...
@@ -276,24 +276,22 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
logits_in
->
data
(),
logits_in
->
data
(),
ntok
*
d
,
dt_logits
,
INFINICCL_SUM
,
rsrc
.
comm
,
stream
));
}
// 2. FFN
// rms_norm
RUN_INFINI
(
infiniopRMSNorm
(
desc_norm
,
workspace
,
workspace_size
,
logits_out
->
data
(),
logits_in
->
data
(),
rsrc
.
w_ffn_norm
[
layer
]
->
data
(),
stream
));
// mlp
RUN_INFINI
(
infiniopGemm
(
desc_ffn_gate_up
,
workspace
,
workspace_size
,
gate_up_buf
->
data
(),
logits_out
->
data
(),
rsrc
.
w_ffn_gate_up
[
layer
]
->
data
(),
1.0
,
0.0
,
stream
));
RUN_INFINI
(
infiniopSwiGLU
(
desc_swiglu
,
workspace
,
workspace_size
,
logits_out
->
data
(),
up_buf
->
data
(),
gate_buf
->
data
(),
stream
));
gate_buf
->
data
(),
up_buf
->
data
(),
gate_buf
->
data
(),
stream
));
RUN_INFINI
(
infiniopGemm
(
desc_ffn_down
,
workspace
,
workspace_size
,
logits_in
->
data
(),
logits_out
->
data
(),
logits_in
->
data
(),
gate_buf
->
data
(),
rsrc
.
w_ffn_down
[
layer
]
->
data
(),
1.0
,
idev
==
0
?
1.0
:
0.0
,
stream
));
// only rank 0 adds residual
// All_reduce if distributed
...
...
@@ -304,7 +302,6 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
}
}
// Sample and Output
uint64_t
tmp
;
if
(
idev
==
0
)
{
size_t
token_offset
=
0
;
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
...
...
@@ -334,10 +331,10 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
token_offset
+=
seq_len
;
}
RUN_INFINI
(
infinirtStreamSynchronize
(
stream
));
RUN_INFINI
(
infinirtMemcpy
(
&
tmp
,
result_buf
->
data
(),
sizeof
(
uint
64
_t
)
*
nreq
,
INFINIRT_MEMCPY_D2H
));
RUN_INFINI
(
infinirtMemcpy
(
result_cpu
.
data
()
,
result_buf
->
data
(),
sizeof
(
uint
32
_t
)
*
nreq
,
INFINIRT_MEMCPY_D2H
));
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
ans
[
req
]
=
(
uint32_t
)
result_cpu
[
req
];
ans
[
req
]
=
result_cpu
[
req
];
}
}
...
...
@@ -378,7 +375,13 @@ inferBatch(struct JiugeModel *model,
std
::
unique_lock
<
std
::
mutex
>
lock
(
model
->
states
[
idev
].
mtx
);
model
->
states
[
idev
].
proceed
=
true
;
lock
.
unlock
();
model
->
states
[
idev
].
cv
.
notify_one
();
model
->
states
[
idev
].
cv_start
.
notify_one
();
}
for
(
size_t
i
=
model
->
dev_ids
.
size
();
i
>
0
;
i
--
)
{
auto
idev
=
i
-
1
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
model
->
states
[
idev
].
mtx
);
model
->
states
[
idev
].
cv_done
.
wait
(
lock
,
[
&
]
{
return
!
(
model
->
states
[
idev
].
proceed
);
});
lock
.
unlock
();
}
}
...
...
@@ -387,7 +390,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
createDeviceResource
(
rsrc
,
&
meta
,
weights
,
device
,
idev
,
ndev
,
dev_id
,
comm
);
while
(
true
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
state
.
mtx
);
state
.
cv
.
wait
(
lock
,
[
&
]
{
return
state
.
proceed
||
state
.
exit_flag
;
});
state
.
cv
_start
.
wait
(
lock
,
[
&
]
{
return
state
.
proceed
||
state
.
exit_flag
;
});
if
(
state
.
exit_flag
)
{
break
;
}
...
...
@@ -396,6 +399,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
state
.
proceed
=
false
;
lock
.
unlock
();
state
.
cv_done
.
notify_one
();
}
infiniopDestroyHandle
(
rsrc
->
handle
);
...
...
@@ -403,8 +407,9 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
infinicclCommDestroy
(
rsrc
->
comm
);
}
JiugeModel
::
JiugeModel
(
const
JiugeMeta
*
_meta
,
const
JiugeWeights
*
weights
,
infiniDevice_t
device
,
std
::
vector
<
int
>
device_ids
)
:
meta
(
*
_meta
)
{
JiugeModel
::
JiugeModel
(
const
JiugeMeta
*
_meta
,
const
JiugeWeights
*
weights
,
infiniDevice_t
device
_
,
std
::
vector
<
int
>
device_ids
)
:
meta
(
*
_meta
)
{
int
ndev
=
int
(
device_ids
.
size
());
device
=
device_
;
dev_ids
=
device_ids
;
dev_resources
=
std
::
vector
<
DeviceResource
>
(
ndev
);
states
=
std
::
vector
<
InferState
>
(
ndev
);
...
...
@@ -439,7 +444,7 @@ __C void destroyJiugeModel(struct JiugeModel *model) {
std
::
unique_lock
<
std
::
mutex
>
lock
(
model
->
states
[
idev
].
mtx
);
model
->
states
[
idev
].
exit_flag
=
true
;
lock
.
unlock
();
model
->
states
[
idev
].
cv
.
notify_one
();
model
->
states
[
idev
].
cv
_start
.
notify_one
();
}
for
(
size_t
idev
=
0
;
idev
<
ndev
;
idev
++
)
{
...
...
src/models/jiuge/jiuge_impl.hpp
View file @
967bcb64
...
...
@@ -28,7 +28,7 @@ struct DeviceResource {
struct
InferState
{
std
::
mutex
mtx
;
std
::
condition_variable
cv
;
std
::
condition_variable
cv
_start
,
cv_done
;
bool
proceed
=
false
;
bool
exit_flag
=
false
;
};
...
...
src/models/jiuge/jiuge_weight.hpp
View file @
967bcb64
...
...
@@ -106,34 +106,34 @@ inline std::shared_ptr<Tensor> getFFNDown(
}
inline
std
::
shared_ptr
<
Tensor
>
getSinTable
(
JiugeMeta
const
*
meta
)
{
float
*
table
=
(
float
*
)
std
::
malloc
(
meta
->
dctx
*
meta
->
dh
*
sizeof
(
float
));
auto
half_dh
=
meta
->
dh
/
2
;
uint16_t
*
table
=
(
uint16_t
*
)
std
::
malloc
(
meta
->
dctx
*
half_dh
*
sizeof
(
uint16_t
));
for
(
size_t
i
=
0
;
i
<
meta
->
dctx
;
i
++
)
{
for
(
size_t
j
=
0
;
j
<
half_dh
;
j
++
)
{
float
_sin
=
std
::
sin
(
static_cast
<
float
>
(
i
)
/
std
::
pow
(
meta
->
theta
,
static_cast
<
float
>
(
j
)
/
half_dh
));
table
[
i
*
meta
->
dh
+
2
*
j
]
=
_sin
;
table
[
i
*
meta
->
dh
+
2
*
j
+
1
]
=
_sin
;
table
[
i
*
half_dh
+
j
]
=
f32_to_f16
(
_sin
);
}
}
auto
shape
=
std
::
vector
<
size_t
>
({
meta
->
dctx
,
meta
->
dh
});
auto
shape
=
std
::
vector
<
size_t
>
({
meta
->
dctx
,
half_
dh
});
auto
tensor
=
Tensor
::
weight
(
table
,
meta
->
dt_logits
,
shape
);
std
::
free
(
table
);
return
tensor
;
}
inline
std
::
shared_ptr
<
Tensor
>
getCosTable
(
JiugeMeta
const
*
meta
)
{
float
*
table
=
(
float
*
)
std
::
malloc
(
meta
->
dctx
*
meta
->
dh
*
sizeof
(
float
));
auto
half_dh
=
meta
->
dh
/
2
;
uint16_t
*
table
=
(
uint16_t
*
)
std
::
malloc
(
meta
->
dctx
*
half_dh
*
sizeof
(
uint16_t
));
for
(
size_t
i
=
0
;
i
<
meta
->
dctx
;
i
++
)
{
for
(
size_t
j
=
0
;
j
<
half_dh
;
j
++
)
{
float
_cos
=
std
::
cos
(
static_cast
<
float
>
(
i
)
/
std
::
pow
(
meta
->
theta
,
static_cast
<
float
>
(
j
)
/
half_dh
));
table
[
i
*
meta
->
dh
+
2
*
j
]
=
_cos
;
table
[
i
*
meta
->
dh
+
2
*
j
+
1
]
=
_cos
;
table
[
i
*
half_dh
+
j
]
=
f32_to_f16
(
_cos
);
}
}
auto
shape
=
std
::
vector
<
size_t
>
({
meta
->
dctx
,
meta
->
dh
});
auto
shape
=
std
::
vector
<
size_t
>
({
meta
->
dctx
,
half_
dh
});
auto
tensor
=
Tensor
::
weight
(
table
,
meta
->
dt_logits
,
shape
);
std
::
free
(
table
);
return
tensor
;
...
...
src/tensor.hpp
View file @
967bcb64
...
...
@@ -44,8 +44,7 @@ private:
std
::
vector
<
ptrdiff_t
>
_strides
;
void
*
_data
;
ptrdiff_t
_offset
;
size_t
_size
;
std
::
shared_ptr
<
Storage
>
storage
;
std
::
shared_ptr
<
Storage
>
_storage
;
infiniopTensorDescriptor_t
_desc
;
void
*
dataImpl
(
ptrdiff_t
offset
)
const
;
...
...
@@ -78,7 +77,6 @@ public:
size_t
ndim
()
const
;
infiniDtype_t
dtype
()
const
;
std
::
shared_ptr
<
TensorDesc
>
desc
()
const
;
size_t
byteSize
()
const
;
ptrdiff_t
dataOffset
()
const
;
infiniDevice_t
deviceType
()
const
;
int
deviceId
()
const
;
...
...
@@ -86,6 +84,7 @@ public:
void
debug
(
const
std
::
string
&
filename
)
const
;
void
debug
()
const
;
std
::
string
info
()
const
;
~
Tensor
();
};
...
...
src/tensor/tensor.cpp
View file @
967bcb64
...
...
@@ -3,6 +3,7 @@
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
std
::
shared_ptr
<
TensorDesc
>
TensorDesc
::
create
(
infiniDtype_t
dtype
,
const
std
::
vector
<
size_t
>
&
shape
,
...
...
@@ -21,13 +22,12 @@ const std::vector<size_t> &Tensor::shape() const { return this->_shape; }
const
std
::
vector
<
ptrdiff_t
>
&
Tensor
::
strides
()
const
{
return
this
->
_strides
;
}
size_t
Tensor
::
ndim
()
const
{
return
this
->
_shape
.
size
();
}
infiniDtype_t
Tensor
::
dtype
()
const
{
return
this
->
_dtype
;
}
size_t
Tensor
::
byteSize
()
const
{
return
this
->
_size
;
}
infiniDevice_t
Tensor
::
deviceType
()
const
{
return
this
->
storage
->
device_type
;
}
int
Tensor
::
deviceId
()
const
{
return
this
->
storage
->
device_id
;
}
infiniDevice_t
Tensor
::
deviceType
()
const
{
return
this
->
_storage
->
device_type
;
}
int
Tensor
::
deviceId
()
const
{
return
this
->
_storage
->
device_id
;
}
Tensor
::~
Tensor
()
{}
ptrdiff_t
Tensor
::
dataOffset
()
const
{
return
(
char
*
)(
this
->
_data
)
-
(
char
*
)(
this
->
storage
->
memory
)
;
return
_offset
;
}
std
::
shared_ptr
<
TensorDesc
>
Tensor
::
desc
()
const
{
return
TensorDesc
::
create
(
this
->
_dtype
,
this
->
_shape
,
this
->
_strides
);
}
...
...
@@ -38,22 +38,19 @@ std::shared_ptr<Tensor> Tensor::buffer(infiniDtype_t dtype,
std
::
shared_ptr
<
Tensor
>
tensor
=
std
::
make_shared
<
Tensor
>
();
tensor
->
_dtype
=
dtype
;
auto
ndim
=
shape
.
size
();
if
(
shape
.
empty
())
{
tensor
->
_shape
=
std
::
vector
<
size_t
>
{
1
};
ndim
=
1
;
}
else
{
tensor
->
_shape
=
std
::
vector
<
size_t
>
(
shape
);
}
tensor
->
_shape
=
std
::
vector
<
size_t
>
(
shape
);
size_t
size
=
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
dsize
(
dtype
),
std
::
multiplies
<
size_t
>
());
auto
strides
=
std
::
vector
<
ptrdiff_t
>
(
ndim
);
strides
[
ndim
-
1
]
=
1
;
for
(
int
i
=
ndim
-
2
;
i
>=
0
;
i
--
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
shape
[
i
+
1
];
if
(
ndim
>
0
)
{
strides
[
ndim
-
1
]
=
1
;
for
(
int
i
=
ndim
-
2
;
i
>=
0
;
i
--
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
shape
[
i
+
1
];
}
}
tensor
->
_strides
=
strides
;
tensor
->
storage
=
Storage
::
createAsync
(
size
,
stream
);
tensor
->
_size
=
size
;
tensor
->
_data
=
tensor
->
storage
->
memory
;
tensor
->
_storage
=
Storage
::
createAsync
(
size
,
stream
);
tensor
->
_data
=
tensor
->
_storage
->
memory
;
infiniopCreateTensorDescriptor
(
&
tensor
->
_desc
,
ndim
,
tensor
->
_shape
.
data
(),
strides
.
data
(),
dtype
);
tensor
->
_offset
=
0
;
...
...
@@ -66,24 +63,20 @@ std::shared_ptr<Tensor> Tensor::weight(void *data, infiniDtype_t dtype,
;
tensor
->
_dtype
=
dtype
;
auto
ndim
=
shape
.
size
();
if
(
shape
.
empty
())
{
tensor
->
_shape
=
std
::
vector
<
size_t
>
{
1
};
ndim
=
1
;
}
else
{
tensor
->
_shape
=
std
::
vector
<
size_t
>
(
shape
);
}
tensor
->
_shape
=
std
::
vector
<
size_t
>
(
shape
);
size_t
size
=
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
dsize
(
dtype
),
std
::
multiplies
<
size_t
>
());
auto
strides
=
std
::
vector
<
ptrdiff_t
>
(
ndim
);
strides
[
ndim
-
1
]
=
1
;
for
(
int
i
=
ndim
-
2
;
i
>=
0
;
i
--
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
shape
[
i
+
1
];
if
(
ndim
>
0
)
{
strides
[
ndim
-
1
]
=
1
;
for
(
int
i
=
ndim
-
2
;
i
>=
0
;
i
--
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
shape
[
i
+
1
];
}
}
tensor
->
_strides
=
strides
;
tensor
->
storage
=
Storage
::
create
(
size
);
RUN_INFINI
(
infinirtMemcpy
(
tensor
->
storage
->
memory
,
tensor
->
_
storage
=
Storage
::
create
(
size
);
RUN_INFINI
(
infinirtMemcpy
(
tensor
->
_
storage
->
memory
,
data
,
size
,
INFINIRT_MEMCPY_H2D
));
tensor
->
_data
=
tensor
->
storage
->
memory
;
tensor
->
_size
=
size
;
tensor
->
_data
=
tensor
->
_storage
->
memory
;
infiniopCreateTensorDescriptor
(
&
tensor
->
_desc
,
ndim
,
tensor
->
_shape
.
data
(),
strides
.
data
(),
dtype
);
tensor
->
_offset
=
0
;
...
...
@@ -91,8 +84,6 @@ std::shared_ptr<Tensor> Tensor::weight(void *data, infiniDtype_t dtype,
}
void
*
Tensor
::
dataImpl
(
ptrdiff_t
offset
)
const
{
ASSERT
(
offset
*
dsize
(
this
->
dtype
())
<
this
->
_size
);
return
(
char
*
)(
this
->
_data
)
+
offset
*
dsize
(
this
->
dtype
());
}
...
...
@@ -139,7 +130,6 @@ void print_data(T *data, const std::vector<size_t> &shape,
}
else
if
(
dim
<
shape
.
size
()
-
1
)
{
for
(
size_t
i
=
0
;
i
<
shape
[
dim
];
i
++
)
{
print_data
(
data
+
i
*
strides
[
dim
],
shape
,
strides
,
dim
+
1
);
std
::
cout
<<
std
::
endl
;
}
}
}
...
...
@@ -151,38 +141,46 @@ void print_data(uint16_t const *data, const std::vector<size_t> &shape,
for
(
size_t
i
=
0
;
i
<
shape
[
dim
];
i
++
)
{
std
::
cout
<<
f16_to_f32
(
data
[
i
*
strides
[
dim
]])
<<
" "
;
}
std
::
cout
<<
std
::
endl
;
}
else
if
(
dim
<
shape
.
size
()
-
1
)
{
for
(
size_t
i
=
0
;
i
<
shape
[
dim
];
i
++
)
{
print_data
(
data
+
i
*
strides
[
dim
],
shape
,
strides
,
dim
+
1
);
std
::
cout
<<
std
::
endl
;
}
}
}
void
Tensor
::
debug
(
const
std
::
string
&
filename
)
const
{
RUN_INFINI
(
infinirtDeviceSynchronize
());
s
td
::
cout
<<
"Tensor: "
<<
"shape[ "
;
std
::
string
Tensor
::
info
(
)
const
{
std
::
stringstream
ss
;
s
s
<<
"Tensor: "
<<
"shape[ "
;
for
(
auto
s
:
this
->
shape
())
{
s
td
::
cout
<<
s
<<
" "
;
s
s
<<
s
<<
" "
;
}
s
td
::
cout
<<
"] strides[ "
;
s
s
<<
"] strides[ "
;
for
(
auto
s
:
this
->
strides
())
{
s
td
::
cout
<<
s
<<
" "
;
s
s
<<
s
<<
" "
;
}
std
::
cout
<<
"] dtype="
<<
this
->
dtype
()
<<
" device="
<<
this
->
deviceType
()
<<
" device_id="
<<
this
->
deviceId
()
<<
std
::
endl
;
ss
<<
"] dtype="
<<
this
->
dtype
()
<<
" device="
<<
this
->
deviceType
()
<<
" device_id="
<<
this
->
deviceId
();
return
ss
.
str
();
}
void
Tensor
::
debug
(
const
std
::
string
&
filename
)
const
{
RUN_INFINI
(
infinirtDeviceSynchronize
());
std
::
cout
<<
info
()
<<
std
::
endl
;
auto
dtype
=
this
->
dtype
();
void
const
*
cpu_data
;
if
(
this
->
deviceType
()
!=
INFINI_DEVICE_CPU
)
{
void
*
cpu_memory
=
std
::
malloc
(
this
->
storage
->
size
);
RUN_INFINI
(
infinirtMemcpy
(
cpu_memory
,
this
->
storage
->
memory
,
this
->
storage
->
size
,
INFINIRT_MEMCPY_D2H
));
void
*
cpu_memory
=
std
::
malloc
(
this
->
_
storage
->
size
);
RUN_INFINI
(
infinirtMemcpy
(
cpu_memory
,
this
->
_
storage
->
memory
,
this
->
_
storage
->
size
,
INFINIRT_MEMCPY_D2H
));
cpu_data
=
cpu_memory
;
}
else
{
cpu_data
=
this
->
data
()
;
cpu_data
=
this
->
_storage
->
memory
;
}
if
(
!
filename
.
empty
())
{
...
...
@@ -191,7 +189,7 @@ void Tensor::debug(const std::string &filename) const {
std
::
cerr
<<
"Error opening file for writing: "
<<
filename
<<
"
\n
"
;
return
;
}
outFile
.
write
(
reinterpret_cast
<
const
char
*>
(
cpu_data
),
this
->
storage
->
size
);
outFile
.
write
(
reinterpret_cast
<
const
char
*>
(
cpu_data
),
this
->
_
storage
->
size
);
outFile
.
close
();
std
::
cout
<<
"Data written to file: "
<<
filename
<<
"
\n
"
;
return
;
...
...
src/tensor/transform.cpp
View file @
967bcb64
...
...
@@ -20,12 +20,10 @@ std::shared_ptr<Tensor> Tensor::sliceImpl(const std::vector<SliceParams> &slices
tensor
->
_dtype
=
this
->
_dtype
;
tensor
->
_shape
=
new_shape
;
tensor
->
_strides
=
std
::
vector
<
ptrdiff_t
>
(
this
->
_strides
);
tensor
->
_offset
=
offset
*
dsize
(
this
->
_dtype
);
tensor
->
_data
=
static_cast
<
char
*
>
(
this
->
_
data
)
+
tensor
->
_offset
;
tensor
->
_offset
=
offset
*
dsize
(
this
->
_dtype
)
+
this
->
_offset
;
tensor
->
_data
=
(
char
*
)
(
this
->
_
storage
->
memory
)
+
tensor
->
_offset
;
tensor
->
_size
=
std
::
accumulate
(
new_shape
.
begin
(),
new_shape
.
end
(),
dsize
(
this
->
_dtype
),
std
::
multiplies
<
size_t
>
());
tensor
->
storage
=
this
->
storage
;
tensor
->
_storage
=
this
->
_storage
;
infiniopCreateTensorDescriptor
(
&
tensor
->
_desc
,
tensor
->
_shape
.
size
(),
tensor
->
_shape
.
data
(),
tensor
->
_strides
.
data
(),
tensor
->
_dtype
);
return
tensor
;
...
...
src/utils.hpp
View file @
967bcb64
...
...
@@ -2,6 +2,7 @@
#define INFINICORE_INFER_UTILS_H
#include <infinicore.h>
#include <cstring>
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
...
...
@@ -70,4 +71,30 @@ inline float f16_to_f32(uint16_t h) {
}
}
inline
uint16_t
f32_to_f16
(
float
val
)
{
uint32_t
f32
;
memcpy
(
&
f32
,
&
val
,
sizeof
(
f32
));
// Read the bits of the float32
uint16_t
sign
=
(
f32
>>
16
)
&
0x8000
;
// Extract the sign bit
int32_t
exponent
=
((
f32
>>
23
)
&
0xFF
)
-
127
;
// Extract and de-bias the exponent
uint32_t
mantissa
=
f32
&
0x7FFFFF
;
// Extract the mantissa (fraction part)
if
(
exponent
>=
31
)
{
// Special cases for Inf and NaN
// NaN
if
(
exponent
==
128
&&
mantissa
!=
0
)
{
return
static_cast
<
uint16_t
>
(
sign
|
0x7E00
);
}
// Infinity
return
static_cast
<
uint16_t
>
(
sign
|
0x7C00
);
}
else
if
(
exponent
>=
-
14
)
{
// Normalized case
return
(
uint16_t
)(
sign
|
((
exponent
+
15
)
<<
10
)
|
(
mantissa
>>
13
));
}
else
if
(
exponent
>=
-
24
)
{
mantissa
|=
0x800000
;
// Add implicit leading 1
mantissa
>>=
(
-
14
-
exponent
);
return
(
uint16_t
)(
sign
|
(
mantissa
>>
13
));
}
else
{
// Too small for subnormal: return signed zero
return
(
uint16_t
)
sign
;
}
}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment