Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
79bc0438
Commit
79bc0438
authored
Jul 23, 2025
by
wooway777
Browse files
issue/21 - replaced most dim operations and added a linear layer
parent
366386d3
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
145 additions
and
50 deletions
+145
-50
src/models/cache_manager.hpp
src/models/cache_manager.hpp
+16
-1
src/models/inference_context.cpp
src/models/inference_context.cpp
+46
-0
src/models/inference_context.hpp
src/models/inference_context.hpp
+9
-0
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+27
-36
src/tensor.hpp
src/tensor.hpp
+4
-2
src/tensor/tensor.cpp
src/tensor/tensor.cpp
+43
-0
src/tensor/transform.cpp
src/tensor/transform.cpp
+0
-11
No files found.
src/models/cache_manager.hpp
View file @
79bc0438
...
@@ -35,6 +35,7 @@ inline size_t computeTensorDescHash(std::shared_ptr<Tensor> tensor) {
...
@@ -35,6 +35,7 @@ inline size_t computeTensorDescHash(std::shared_ptr<Tensor> tensor) {
}
}
enum
class
OperatorType
{
enum
class
OperatorType
{
ADD
,
RMS_NORM
,
RMS_NORM
,
GEMM
,
GEMM
,
ROPE
,
ROPE
,
...
@@ -66,6 +67,9 @@ private:
...
@@ -66,6 +67,9 @@ private:
void
destroyDescriptor
(
DescriptorType
&
desc
)
{
void
destroyDescriptor
(
DescriptorType
&
desc
)
{
switch
(
opType
)
{
switch
(
opType
)
{
case
OperatorType
::
ADD
:
infiniopDestroyAddDescriptor
(
desc
);
break
;
case
OperatorType
::
RMS_NORM
:
case
OperatorType
::
RMS_NORM
:
infiniopDestroyRMSNormDescriptor
(
desc
);
infiniopDestroyRMSNormDescriptor
(
desc
);
break
;
break
;
...
@@ -178,6 +182,7 @@ class CacheManager {
...
@@ -178,6 +182,7 @@ class CacheManager {
private:
private:
const
size_t
DEFAULT_CACHE_CAPACITY
=
128
;
const
size_t
DEFAULT_CACHE_CAPACITY
=
128
;
LRUDescriptorCache
<
infiniopAddDescriptor_t
>
add_cache
;
LRUDescriptorCache
<
infiniopRMSNormDescriptor_t
>
rms_norm_cache
;
LRUDescriptorCache
<
infiniopRMSNormDescriptor_t
>
rms_norm_cache
;
LRUDescriptorCache
<
infiniopGemmDescriptor_t
>
gemm_cache
;
LRUDescriptorCache
<
infiniopGemmDescriptor_t
>
gemm_cache
;
LRUDescriptorCache
<
infiniopRoPEDescriptor_t
>
rope_cache
;
LRUDescriptorCache
<
infiniopRoPEDescriptor_t
>
rope_cache
;
...
@@ -187,7 +192,8 @@ private:
...
@@ -187,7 +192,8 @@ private:
LRUDescriptorCache
<
infiniopRandomSampleDescriptor_t
>
random_sample_cache
;
LRUDescriptorCache
<
infiniopRandomSampleDescriptor_t
>
random_sample_cache
;
public:
public:
CacheManager
(
size_t
capacity
=
100
)
:
rms_norm_cache
(
capacity
,
OperatorType
::
RMS_NORM
),
CacheManager
(
size_t
capacity
=
100
)
:
add_cache
(
capacity
,
OperatorType
::
ADD
),
rms_norm_cache
(
capacity
,
OperatorType
::
RMS_NORM
),
gemm_cache
(
capacity
,
OperatorType
::
GEMM
),
gemm_cache
(
capacity
,
OperatorType
::
GEMM
),
rope_cache
(
capacity
,
OperatorType
::
ROPE
),
rope_cache
(
capacity
,
OperatorType
::
ROPE
),
rearrange_cache
(
capacity
,
OperatorType
::
REARRANGE
),
rearrange_cache
(
capacity
,
OperatorType
::
REARRANGE
),
...
@@ -195,6 +201,15 @@ public:
...
@@ -195,6 +201,15 @@ public:
swiglu_cache
(
capacity
,
OperatorType
::
SWIGLU
),
swiglu_cache
(
capacity
,
OperatorType
::
SWIGLU
),
random_sample_cache
(
capacity
,
OperatorType
::
RANDOM_SAMPLE
)
{}
random_sample_cache
(
capacity
,
OperatorType
::
RANDOM_SAMPLE
)
{}
// Add operations
// Looks up `key` in the Add-descriptor cache, forwarding `desc` as the
// output argument. The result of the cache lookup is returned as-is.
bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) {
    const bool hit = add_cache.get(key, desc);
    return hit;
}
// Stores `desc` in the Add-descriptor cache under `key`.
void putAddDescriptor(size_t key,
                      const infiniopAddDescriptor_t &desc) {
    add_cache.put(key, desc);
}
// RMSNorm operations
// RMSNorm operations
bool
getRMSNormDescriptor
(
size_t
key
,
infiniopRMSNormDescriptor_t
&
desc
)
{
bool
getRMSNormDescriptor
(
size_t
key
,
infiniopRMSNormDescriptor_t
&
desc
)
{
return
rms_norm_cache
.
get
(
key
,
desc
);
return
rms_norm_cache
.
get
(
key
,
desc
);
...
...
src/models/inference_context.cpp
View file @
79bc0438
...
@@ -12,6 +12,28 @@ void InferenceContext::ensure_workspace(size_t required_size) {
...
@@ -12,6 +12,28 @@ void InferenceContext::ensure_workspace(size_t required_size) {
}
}
}
}
void
InferenceContext
::
add
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
b
)
{
size_t
key
=
CacheManager
::
createDescriptorKey
(
c
,
a
,
b
,
nullptr
,
nullptr
);
infiniopAddDescriptor_t
desc
;
if
(
!
cache_manager
->
getAddDescriptor
(
key
,
desc
))
{
RUN_INFINI
(
infiniopCreateAddDescriptor
(
rsrc
->
handle
,
&
desc
,
c
->
desc
(),
a
->
desc
(),
b
->
desc
()));
cache_manager
->
putAddDescriptor
(
key
,
desc
);
}
size_t
workspace_size
=
0
;
RUN_INFINI
(
infiniopGetAddWorkspaceSize
(
desc
,
&
workspace_size
));
ensure_workspace
(
workspace_size
);
void
*
workspace
=
workspace_storage
->
memory
();
RUN_INFINI
(
infiniopAdd
(
desc
,
workspace
,
workspace_size
,
c
->
data
(),
a
->
data
(),
b
->
data
(),
stream
));
}
void
InferenceContext
::
rmsnorm
(
std
::
shared_ptr
<
Tensor
>
y
,
void
InferenceContext
::
rmsnorm
(
std
::
shared_ptr
<
Tensor
>
y
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
Tensor
>
w
,
std
::
shared_ptr
<
Tensor
>
w
,
...
@@ -165,3 +187,27 @@ void InferenceContext::randomSample(std::shared_ptr<Tensor> out,
...
@@ -165,3 +187,27 @@ void InferenceContext::randomSample(std::shared_ptr<Tensor> out,
random_val
,
top_p
,
top_k
,
temperature
,
random_val
,
top_p
,
top_k
,
temperature
,
stream
));
stream
));
}
}
// Linear layer built on gemm(): as used in this file, gemm(c, a, b, alpha, beta)
// scales the existing contents of `c` by `beta` and accumulates alpha * (a x b),
// so this function produces
//
//     c = alpha * (a x b) + beta * c + residual      (residual term optional)
//
// Parameters:
//   c        - output tensor (also read when beta != 0 or when it is the residual).
//   a, b     - GEMM operands.
//   alpha    - scale applied to the product.
//   beta     - scale applied to the previous contents of `c`.
//   residual - optional tensor added to the result; pass nullptr for none.
//
// When `residual` shares its data pointer with `c`, the residual is folded into
// the GEMM itself: alpha*ab + beta*c + c == alpha*ab + (beta + 1)*c. This covers
// the beta == 0 case (plain in-place accumulate) and avoids the temporary copy
// and extra add() pass otherwise needed to preserve the old `c`.
// NOTE(review): aliasing is detected by data() pointer equality only; partially
// overlapping views are not recognised.
void InferenceContext::linear(std::shared_ptr<Tensor> c,
                              std::shared_ptr<Tensor> a,
                              std::shared_ptr<Tensor> b,
                              float alpha,
                              float beta,
                              std::shared_ptr<Tensor> residual) {
    if (!residual) {
        gemm(c, a, b, alpha, beta);
        return;
    }

    if (residual->data() == c->data()) {
        // In-place residual: let the GEMM's beta term carry it.
        gemm(c, a, b, alpha, beta + 1.0f);
        return;
    }

    // Distinct residual buffer: compute the GEMM, then add the residual on top.
    gemm(c, a, b, alpha, beta);
    add(c, c, residual);
}
src/models/inference_context.hpp
View file @
79bc0438
...
@@ -16,6 +16,9 @@ struct InferenceContext {
...
@@ -16,6 +16,9 @@ struct InferenceContext {
void
ensure_workspace
(
size_t
required_size
);
void
ensure_workspace
(
size_t
required_size
);
void
add
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
b
);
void
rmsnorm
(
std
::
shared_ptr
<
Tensor
>
y
,
void
rmsnorm
(
std
::
shared_ptr
<
Tensor
>
y
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
Tensor
>
x
,
std
::
shared_ptr
<
Tensor
>
w
,
std
::
shared_ptr
<
Tensor
>
w
,
...
@@ -39,4 +42,10 @@ struct InferenceContext {
...
@@ -39,4 +42,10 @@ struct InferenceContext {
void
randomSample
(
std
::
shared_ptr
<
Tensor
>
out
,
void
randomSample
(
std
::
shared_ptr
<
Tensor
>
out
,
std
::
shared_ptr
<
Tensor
>
prob
,
std
::
shared_ptr
<
Tensor
>
prob
,
float
random_val
,
float
top_p
,
uint32_t
top_k
,
float
temperature
);
float
random_val
,
float
top_p
,
uint32_t
top_k
,
float
temperature
);
void
linear
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
b
,
float
alpha
,
float
beta
,
std
::
shared_ptr
<
Tensor
>
residual
);
};
};
src/models/jiuge/jiuge.cpp
View file @
79bc0438
...
@@ -166,7 +166,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -166,7 +166,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
}
// Attention
// Attention
qkv_buf
->
dimSplit
(
1
,
{
nh
+
nkvh
*
2
,
dh
});
// (ntok, nh + 2 * nkvh, dh)
// attention inner
// attention inner
size_t
max_qk_size
=
0
;
size_t
max_qk_size
=
0
;
size_t
max_seq_len
=
0
;
size_t
max_seq_len
=
0
;
...
@@ -194,56 +193,49 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -194,56 +193,49 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// rms norm
// rms norm
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_attn_norm
[
layer
],
meta
.
epsilon
);
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_attn_norm
[
layer
],
meta
.
epsilon
);
// qkv_proj
// qkv_proj
qkv_buf
->
dimMerge
(
1
,
2
);
if
(
has_qkv_bias
)
{
if
(
has_qkv_bias
)
{
ctx
.
rearrange
(
qkv_buf
,
rsrc
.
b_attn_qkv
[
layer
]
->
reDesc
({
ntok
,
(
nh
+
nkvh
*
2
)
*
dh
},
{
0
,
1
}));
ctx
.
rearrange
(
qkv_buf
,
rsrc
.
b_attn_qkv
[
layer
]
->
view
({
ntok
,
(
nh
+
nkvh
*
2
)
*
dh
},
{
0
,
1
}));
}
}
ctx
.
gemm
(
qkv_buf
,
logits_out
,
rsrc
.
w_attn_qkv
[
layer
],
1.0
,
has_qkv_bias
?
1.0
:
0.0
);
ctx
.
linear
(
qkv_buf
,
logits_out
,
rsrc
.
w_attn_qkv
[
layer
],
1.0
,
0.0
,
has_qkv_bias
?
qkv_buf
:
nullptr
);
// rope
// rope
qkv_buf
->
dimSplit
(
1
,
{
nh
+
nkvh
*
2
,
dh
});
auto
qkv_rope
=
qkv_buf
->
viewReshaped
({
ntok
,
nh
+
nkvh
*
2
,
dh
});
ctx
.
rope
(
qkv_
buf
->
slice
(
1
,
0
,
nh
),
qkv_
buf
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
ctx
.
rope
(
qkv_
rope
->
slice
(
1
,
0
,
nh
),
qkv_
rope
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
ctx
.
rope
(
qkv_
buf
->
slice
(
1
,
nh
,
nkvh
),
qkv_
buf
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
ctx
.
rope
(
qkv_
rope
->
slice
(
1
,
nh
,
nkvh
),
qkv_
rope
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
size_t
token_offset
=
0
;
size_t
token_offset
=
0
;
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
auto
past_len
=
req_pos
[
req
];
auto
past_len
=
req_pos
[
req
];
auto
seq_len
=
req_lens
[
req
];
auto
seq_len
=
req_lens
[
req
];
auto
total_len
=
past_len
+
seq_len
;
auto
total_len
=
past_len
+
seq_len
;
auto
o
=
o_buf
->
dimSplit
(
1
,
{
nh
,
dh
})
->
slice
({{
0
,
token_offset
,
seq_len
}});
auto
o
=
o_buf
->
viewReshaped
({
ntok
,
nh
,
dh
})
->
slice
({{
0
,
token_offset
,
seq_len
}})
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
})
;
auto
q
=
qkv_
buf
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
0
,
nh
}});
auto
q
=
qkv_
rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
0
,
nh
}})
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
})
;
auto
k
=
qkv_
buf
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
,
nkvh
}});
auto
k
=
qkv_
rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
,
nkvh
}});
auto
v
=
qkv_
buf
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
+
nkvh
,
nkvh
}});
auto
v
=
qkv_
rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
+
nkvh
,
nkvh
}});
// self attention
// self attention
// concat
// concat
ctx
.
rearrange
(
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
k
);
ctx
.
rearrange
(
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
k
);
ctx
.
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
ctx
.
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
// qk
// qk
ctx
.
rearrange
(
rearrange_q_buf
->
dimSplit
(
1
,
{
ngroup
,
seq_len
}),
auto
q_rearrange
=
rearrange_q_buf
->
viewReshaped
({
nkvh
,
ngroup
,
seq_len
,
dh
});
q
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
}));
ctx
.
rearrange
(
q_rearrange
,
q
);
qk_buf
->
dimSplit
(
1
,
{
seq_len
,
total_len
});
auto
qk_gemm
=
qk_buf
->
viewReshaped
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
qk_buf
->
dimSplit
(
0
,
{
nkvh
,
ngroup
});
auto
k_gemm
=
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
});
qk_buf
->
dimMerge
(
1
,
2
);
ctx
.
linear
(
qk_gemm
,
rearrange_q_buf
,
k_gemm
,
1.
/
sqrt
(
dh
),
0.0
,
nullptr
);
ctx
.
gemm
(
qk_buf
,
rearrange_q_buf
->
dimMerge
(
1
,
2
),
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
}),
1.
/
sqrt
(
dh
),
0.0
);
// softmax
// softmax
qk_buf
->
dimSplit
(
1
,
{
ngroup
,
seq_len
});
auto
qk_softmax
=
qk_buf
->
viewReshaped
({
nh
,
seq_len
,
total_len
});
qk_buf
->
dimMerge
(
0
,
1
);
ctx
.
causalSoftmax
(
qk_softmax
,
qk_softmax
);
ctx
.
causalSoftmax
(
qk_buf
,
qk_buf
);
auto
v_gemm
=
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
});
qk_buf
->
dimSplit
(
0
,
{
nkvh
,
ngroup
});
ctx
.
linear
(
attn_val_buf
,
qk_gemm
,
v_gemm
,
1.0
,
0.0
,
nullptr
);
qk_buf
->
dimMerge
(
1
,
2
);
ctx
.
gemm
(
attn_val_buf
,
qk_buf
,
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
}),
1.0
,
0.0
);
qk_buf
->
dimSplit
(
1
,
{
ngroup
,
seq_len
});
qk_buf
->
dimMerge
(
2
,
3
);
qk_buf
->
dimMerge
(
0
,
1
);
// rearrange attn val
// rearrange attn val
attn_val_buf
->
dimSplit
(
0
,
{
nkvh
,
ngroup
});
auto
attn_val_gemm
=
attn_val_buf
->
viewReshaped
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
ctx
.
rearrange
(
o
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
}),
attn_val_buf
);
ctx
.
rearrange
(
o
,
attn_val_gemm
);
attn_val_buf
->
dimMerge
(
0
,
1
);
token_offset
+=
seq_len
;
token_offset
+=
seq_len
;
}
}
// o_proj
// o_proj
ctx
.
gemm
(
logits_in
,
o_buf
->
dimMerge
(
1
,
2
)
,
rsrc
.
w_attn_out
[
layer
],
1.0
,
idev
==
0
?
1.0
:
0.0
);
// only rank 0 adds residual
ctx
.
linear
(
logits_in
,
o_buf
,
rsrc
.
w_attn_out
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
);
// only rank 0 adds residual
// All_reduce if distributed
// All_reduce if distributed
if
(
rsrc
.
comm
!=
nullptr
)
{
if
(
rsrc
.
comm
!=
nullptr
)
{
...
@@ -253,11 +245,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -253,11 +245,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
RUN_INFINI
(
infinirtStreamSynchronize
(
stream
));
RUN_INFINI
(
infinirtStreamSynchronize
(
stream
));
}
}
// 2. FFN
// 2. FFN
// rms_norm
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_ffn_norm
[
layer
],
meta
.
epsilon
);
ctx
.
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_ffn_norm
[
layer
],
meta
.
epsilon
);
ctx
.
gemm
(
gate_up_buf
,
logits_out
,
rsrc
.
w_ffn_gate_up
[
layer
],
1.0
,
0.0
);
ctx
.
linear
(
gate_up_buf
,
logits_out
,
rsrc
.
w_ffn_gate_up
[
layer
],
1.0
,
0.0
,
nullptr
);
ctx
.
swiglu
(
gate_buf
,
up_buf
,
gate_buf
);
ctx
.
swiglu
(
gate_buf
,
up_buf
,
gate_buf
);
ctx
.
gemm
(
logits_in
,
gate_buf
,
rsrc
.
w_ffn_down
[
layer
],
1.0
,
idev
==
0
?
1.0
:
0.0
);
// only rank 0 adds residual
ctx
.
linear
(
logits_in
,
gate_buf
,
rsrc
.
w_ffn_down
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
);
// only rank 0 adds residual
// All_reduce if distributed
// All_reduce if distributed
if
(
rsrc
.
comm
!=
nullptr
)
{
if
(
rsrc
.
comm
!=
nullptr
)
{
...
@@ -278,15 +269,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -278,15 +269,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rsrc
.
w_out_norm
,
rsrc
.
w_out_norm
,
meta
.
epsilon
);
meta
.
epsilon
);
}
}
ctx
.
gemm
(
prob_buf
,
logits_out
->
slice
(
0
,
0
,
nreq
),
rsrc
.
w_out_embd
,
1.0
,
0.0
);
ctx
.
linear
(
prob_buf
,
logits_out
->
slice
(
0
,
0
,
nreq
),
rsrc
.
w_out_embd
,
1.0
,
0.0
,
nullptr
);
std
::
random_device
_rd
;
std
::
random_device
_rd
;
std
::
mt19937
gen
(
_rd
());
std
::
mt19937
gen
(
_rd
());
token_offset
=
0
;
token_offset
=
0
;
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
auto
seq_len
=
req_lens
[
req
];
auto
seq_len
=
req_lens
[
req
];
float
random_val
=
std
::
uniform_real_distribution
<
float
>
(
0
,
1
)(
gen
);
float
random_val
=
std
::
uniform_real_distribution
<
float
>
(
0
,
1
)(
gen
);
ctx
.
randomSample
(
result_buf
->
reDesc
({},
{}),
ctx
.
randomSample
(
result_buf
->
view
({},
{}),
prob_buf
->
reDesc
({
dvoc
},
{
1
}),
prob_buf
->
view
({
dvoc
},
{
1
}),
random_val
,
topp
[
req
],
topk
[
req
],
temperature
[
req
]);
random_val
,
topp
[
req
],
topk
[
req
],
temperature
[
req
]);
token_offset
+=
seq_len
;
token_offset
+=
seq_len
;
}
}
...
...
src/tensor.hpp
View file @
79bc0438
...
@@ -78,7 +78,6 @@ public:
...
@@ -78,7 +78,6 @@ public:
void
dimMerge
(
size_t
dim_start
,
size_t
dim_end
);
void
dimMerge
(
size_t
dim_start
,
size_t
dim_end
);
void
dimSplit
(
size_t
dim
,
const
std
::
vector
<
size_t
>
&
dims
);
void
dimSplit
(
size_t
dim
,
const
std
::
vector
<
size_t
>
&
dims
);
void
permute
(
const
std
::
vector
<
size_t
>
&
order
);
void
permute
(
const
std
::
vector
<
size_t
>
&
order
);
void
reDesc
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
);
};
};
class
Tensor
:
public
std
::
enable_shared_from_this
<
Tensor
>
{
class
Tensor
:
public
std
::
enable_shared_from_this
<
Tensor
>
{
...
@@ -111,7 +110,6 @@ public:
...
@@ -111,7 +110,6 @@ public:
std
::
shared_ptr
<
Tensor
>
dimSplit
(
size_t
dim
,
std
::
shared_ptr
<
Tensor
>
dimSplit
(
size_t
dim
,
const
std
::
vector
<
size_t
>
&
dims
);
const
std
::
vector
<
size_t
>
&
dims
);
std
::
shared_ptr
<
Tensor
>
permute
(
const
std
::
vector
<
size_t
>
&
order
);
std
::
shared_ptr
<
Tensor
>
permute
(
const
std
::
vector
<
size_t
>
&
order
);
std
::
shared_ptr
<
Tensor
>
reDesc
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
);
void
*
data
(
ptrdiff_t
offset
=
0
);
void
*
data
(
ptrdiff_t
offset
=
0
);
void
const
*
data
(
ptrdiff_t
offset
=
0
)
const
;
void
const
*
data
(
ptrdiff_t
offset
=
0
)
const
;
void
copyFrom
(
std
::
shared_ptr
<
Tensor
const
>
src
,
infiniopHandle_t
handle
,
void
copyFrom
(
std
::
shared_ptr
<
Tensor
const
>
src
,
infiniopHandle_t
handle
,
...
@@ -130,6 +128,10 @@ public:
...
@@ -130,6 +128,10 @@ public:
void
debug
()
const
;
void
debug
()
const
;
std
::
string
info
()
const
;
std
::
string
info
()
const
;
std
::
shared_ptr
<
Tensor
>
view
()
const
;
std
::
shared_ptr
<
Tensor
>
view
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
const
;
std
::
shared_ptr
<
Tensor
>
viewReshaped
(
const
std
::
vector
<
size_t
>
new_shape
)
const
;
~
Tensor
();
~
Tensor
();
};
};
...
...
src/tensor/tensor.cpp
View file @
79bc0438
...
@@ -258,6 +258,49 @@ std::string Tensor::info() const {
...
@@ -258,6 +258,49 @@ std::string Tensor::info() const {
return
this
->
_desc
->
info
();
return
this
->
_desc
->
info
();
}
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
view
()
const
{
std
::
shared_ptr
<
Tensor
>
tensor
=
std
::
make_shared
<
Tensor
>
();
tensor
->
_storage
=
this
->
_storage
;
tensor
->
_desc
=
TensorDesc
::
create
(
this
->
dtype
(),
this
->
shape
(),
this
->
strides
());
tensor
->
_offset
=
this
->
_offset
;
return
tensor
;
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
view
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
const
{
std
::
shared_ptr
<
Tensor
>
tensor
=
std
::
make_shared
<
Tensor
>
();
tensor
->
_storage
=
this
->
_storage
;
tensor
->
_desc
=
TensorDesc
::
create
(
this
->
dtype
(),
new_shape
,
new_strides
);
tensor
->
_offset
=
this
->
_offset
;
return
tensor
;
}
// Returns a new tensor (created via view(), so `this` is left untouched) whose
// dimensions have been merged into one and then split into `new_shape`.
//
// Asserts that `new_shape` describes the same number of elements as the
// current shape.
// NOTE(review): the merge is delegated to dimMerge(); whether it requires
// particular (e.g. contiguous) strides is decided there — confirm for callers
// that pass sliced or permuted tensors.
std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> new_shape) const {
    // Take one snapshot of the shape so begin()/end() come from the same object.
    const auto current_shape = _desc->shape();

    // Seed with size_t{1}: std::accumulate accumulates in the type of the
    // initial value, so an int literal would compute the product in int.
    const size_t current_elements = std::accumulate(
        current_shape.begin(), current_shape.end(), size_t{1}, std::multiplies<size_t>());
    const size_t new_elements = std::accumulate(
        new_shape.begin(), new_shape.end(), size_t{1}, std::multiplies<size_t>());
    ASSERT_EQ(current_elements, new_elements);

    // Work on a fresh view so the in-place dim operations below do not
    // modify this tensor's descriptor.
    auto result = this->view();

    // Step 1: collapse all existing dimensions into one (nothing to do for rank <= 1).
    if (current_shape.size() > 1) {
        result = result->dimMerge(0, current_shape.size() - 1);
    }

    // Step 2: split the single dimension into the requested shape.
    if (new_shape.size() > 1) {
        result = result->dimSplit(0, new_shape);
    }

    return result;
}
void
Tensor
::
debug
(
const
std
::
string
&
filename
)
const
{
void
Tensor
::
debug
(
const
std
::
string
&
filename
)
const
{
RUN_INFINI
(
infinirtDeviceSynchronize
());
RUN_INFINI
(
infinirtDeviceSynchronize
());
...
...
src/tensor/transform.cpp
View file @
79bc0438
...
@@ -114,14 +114,3 @@ std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
...
@@ -114,14 +114,3 @@ std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
this
->
_desc
->
permute
(
order
);
this
->
_desc
->
permute
(
order
);
return
shared_from_this
();
return
shared_from_this
();
}
}
void
TensorDesc
::
reDesc
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
{
this
->
_shape
=
new_shape
;
this
->
_strides
=
new_strides
;
this
->
resetDesc
();
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
reDesc
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
{
this
->
_desc
->
reDesc
(
new_shape
,
new_strides
);
return
shared_from_this
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment