Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
6998a8f1
Commit
6998a8f1
authored
Aug 08, 2025
by
wooway777
Browse files
issue/21 - Improved linear and view implementations
parent
8853663e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
42 additions
and
43 deletions
+42
-43
src/models/cache_manager.hpp
src/models/cache_manager.hpp
+1
-1
src/models/inference_context.cpp
src/models/inference_context.cpp
+11
-1
src/models/inference_context.hpp
src/models/inference_context.hpp
+4
-3
src/models/jiuge/jiuge.cpp
src/models/jiuge/jiuge.cpp
+15
-18
src/tensor.hpp
src/tensor.hpp
+2
-3
src/tensor/tensor.cpp
src/tensor/tensor.cpp
+9
-17
No files found.
src/models/cache_manager.hpp
View file @
6998a8f1
...
@@ -22,7 +22,7 @@ inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_
...
@@ -22,7 +22,7 @@ inline void hash_combine(size_t &seed, T value, typename std::enable_if<std::is_
}
}
// Helper function to compute hash for tensor descriptors
// Helper function to compute hash for tensor descriptors
inline
size_t
computeTensorDescHash
(
std
::
shared_ptr
<
Tensor
>
tensor
)
{
inline
size_t
computeTensorDescHash
(
std
::
shared_ptr
<
Tensor
>
&
tensor
)
{
size_t
seed
=
0
;
size_t
seed
=
0
;
hash_combine
(
seed
,
tensor
->
dtype
());
hash_combine
(
seed
,
tensor
->
dtype
());
for
(
auto
dim
:
tensor
->
shape
())
{
for
(
auto
dim
:
tensor
->
shape
())
{
...
...
src/models/inference_context.cpp
View file @
6998a8f1
...
@@ -190,7 +190,8 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
...
@@ -190,7 +190,8 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
b
,
std
::
shared_ptr
<
Tensor
>
b
,
float
alpha
,
float
beta
,
float
alpha
,
float
beta
,
std
::
shared_ptr
<
Tensor
>
residual
)
{
std
::
shared_ptr
<
Tensor
>
residual
,
std
::
shared_ptr
<
Tensor
>
bias
)
{
if
(
residual
)
{
if
(
residual
)
{
if
(
residual
->
data
()
==
c
->
data
())
{
if
(
residual
->
data
()
==
c
->
data
())
{
if
(
beta
==
0.0
)
{
if
(
beta
==
0.0
)
{
...
@@ -208,4 +209,13 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
...
@@ -208,4 +209,13 @@ void InferenceContext::linear(std::shared_ptr<Tensor> c,
}
else
{
}
else
{
gemm
(
c
,
a
,
b
,
alpha
,
beta
);
gemm
(
c
,
a
,
b
,
alpha
,
beta
);
}
}
if
(
bias
)
{
int
ndim_diff
=
c
->
ndim
()
-
1
;
ASSERT_EQ
(
bias
->
ndim
(),
1
);
ASSERT_EQ
(
bias
->
shape
()[
0
],
c
->
shape
()[
ndim_diff
]);
std
::
vector
<
ptrdiff_t
>
strides
(
ndim_diff
,
0
);
strides
.
push_back
(
bias
->
strides
()[
0
]);
add
(
c
,
c
,
bias
->
view_as
(
c
->
shape
(),
strides
));
}
}
}
src/models/inference_context.hpp
View file @
6998a8f1
...
@@ -47,7 +47,8 @@ struct InferenceContext {
...
@@ -47,7 +47,8 @@ struct InferenceContext {
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
b
,
std
::
shared_ptr
<
Tensor
>
b
,
float
alpha
,
float
beta
,
float
alpha
,
float
beta
,
std
::
shared_ptr
<
Tensor
>
residual
);
std
::
shared_ptr
<
Tensor
>
residual
,
std
::
shared_ptr
<
Tensor
>
bias
);
};
};
namespace
{
namespace
{
...
@@ -103,6 +104,6 @@ inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> pr
...
@@ -103,6 +104,6 @@ inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> pr
inline
void
linear
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
Tensor
>
a
,
inline
void
linear
(
std
::
shared_ptr
<
Tensor
>
c
,
std
::
shared_ptr
<
Tensor
>
a
,
std
::
shared_ptr
<
Tensor
>
b
,
float
alpha
,
float
beta
,
std
::
shared_ptr
<
Tensor
>
b
,
float
alpha
,
float
beta
,
std
::
shared_ptr
<
Tensor
>
residual
)
{
std
::
shared_ptr
<
Tensor
>
residual
,
std
::
shared_ptr
<
Tensor
>
bias
)
{
getInferenceContext
().
linear
(
c
,
a
,
b
,
alpha
,
beta
,
residual
);
getInferenceContext
().
linear
(
c
,
a
,
b
,
alpha
,
beta
,
residual
,
bias
);
}
}
src/models/jiuge/jiuge.cpp
View file @
6998a8f1
...
@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -141,7 +141,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
result_buf
=
Tensor
::
buffer
(
INFINI_DTYPE_I64
,
{
nreq
},
rsrc
.
memory_pool
);
auto
result_buf
=
Tensor
::
buffer
(
INFINI_DTYPE_I64
,
{
nreq
},
rsrc
.
memory_pool
);
auto
result_cpu
=
std
::
vector
<
int64_t
>
(
nreq
);
auto
result_cpu
=
std
::
vector
<
int64_t
>
(
nreq
);
auto
qkv_rope
=
qkv_buf
->
view
Reshaped
({
ntok
,
nh
+
nkvh
*
2
,
dh
});
auto
qkv_rope
=
qkv_buf
->
view
({
ntok
,
nh
+
nkvh
*
2
,
dh
});
// Prepare inputs
// Prepare inputs
auto
batch_pos_ids
=
std
::
vector
<
uint32_t
>
(
ntok
);
auto
batch_pos_ids
=
std
::
vector
<
uint32_t
>
(
ntok
);
...
@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -183,9 +183,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
qk_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nh
,
max_qk_size
},
rsrc
.
memory_pool
);
auto
qk_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nh
,
max_qk_size
},
rsrc
.
memory_pool
);
auto
rearrange_q_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
rearrange_q_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
q_rearrange
=
rearrange_q_buf
->
view
Reshaped
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
auto
q_rearrange
=
rearrange_q_buf
->
view
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
auto
attn_val_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
attn_val_buf
=
Tensor
::
buffer
(
dt_logits
,
{
nkvh
,
ngroup
*
max_seq_len
,
dh
},
rsrc
.
memory_pool
);
auto
attn_val_gemm
=
attn_val_buf
->
view
Reshaped
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
auto
attn_val_gemm
=
attn_val_buf
->
view
({
nkvh
,
ngroup
,
max_seq_len
,
dh
});
// MLP buffers
// MLP buffers
auto
gate_buf
=
gate_up_buf
->
slice
(
1
,
0
,
di
);
auto
gate_buf
=
gate_up_buf
->
slice
(
1
,
0
,
di
);
...
@@ -197,10 +197,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -197,10 +197,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// rms norm
// rms norm
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_attn_norm
[
layer
],
meta
.
epsilon
);
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_attn_norm
[
layer
],
meta
.
epsilon
);
// qkv_proj
// qkv_proj
if
(
has_qkv_bias
)
{
linear
(
qkv_buf
,
logits_out
,
rsrc
.
w_attn_qkv
[
layer
],
1.0
,
0.0
,
nullptr
,
has_qkv_bias
?
rsrc
.
b_attn_qkv
[
layer
]
:
nullptr
);
rearrange
(
qkv_buf
,
rsrc
.
b_attn_qkv
[
layer
]
->
view
({
ntok
,
(
nh
+
nkvh
*
2
)
*
dh
},
{
0
,
1
}));
}
linear
(
qkv_buf
,
logits_out
,
rsrc
.
w_attn_qkv
[
layer
],
1.0
,
0.0
,
has_qkv_bias
?
qkv_buf
:
nullptr
);
// rope
// rope
rope
(
qkv_rope
->
slice
(
1
,
0
,
nh
),
qkv_rope
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
rope
(
qkv_rope
->
slice
(
1
,
0
,
nh
),
qkv_rope
->
slice
(
1
,
0
,
nh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
rope
(
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
rope
(
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
qkv_rope
->
slice
(
1
,
nh
,
nkvh
),
pos_ids_buf
,
rsrc
.
sin_table
,
rsrc
.
cos_table
);
...
@@ -210,7 +207,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -210,7 +207,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto
past_len
=
req_pos
[
req
];
auto
past_len
=
req_pos
[
req
];
auto
seq_len
=
req_lens
[
req
];
auto
seq_len
=
req_lens
[
req
];
auto
total_len
=
past_len
+
seq_len
;
auto
total_len
=
past_len
+
seq_len
;
auto
o
=
o_buf
->
view
Reshaped
({
ntok
,
nh
,
dh
})
->
slice
({{
0
,
token_offset
,
seq_len
}})
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
});
auto
o
=
o_buf
->
view
({
ntok
,
nh
,
dh
})
->
slice
({{
0
,
token_offset
,
seq_len
}})
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
});
auto
q
=
qkv_rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
0
,
nh
}})
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
});
auto
q
=
qkv_rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
0
,
nh
}})
->
dimSplit
(
1
,
{
nkvh
,
ngroup
})
->
permute
({
1
,
2
,
0
,
3
});
auto
k
=
qkv_rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
,
nkvh
}});
auto
k
=
qkv_rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
,
nkvh
}});
auto
v
=
qkv_rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
+
nkvh
,
nkvh
}});
auto
v
=
qkv_rope
->
slice
({{
0
,
token_offset
,
seq_len
},
{
1
,
nh
+
nkvh
,
nkvh
}});
...
@@ -221,14 +218,14 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -221,14 +218,14 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
rearrange
(
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
past_len
,
seq_len
),
v
);
// qk
// qk
rearrange
(
q_rearrange
,
q
);
rearrange
(
q_rearrange
,
q
);
auto
qk_gemm
=
qk_buf
->
view
Reshaped
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
auto
qk_gemm
=
qk_buf
->
view
({
nkvh
,
ngroup
*
seq_len
,
total_len
});
auto
k_gemm
=
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
});
auto
k_gemm
=
kv_caches
[
req
]
->
k
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
2
,
0
});
linear
(
qk_gemm
,
rearrange_q_buf
,
k_gemm
,
1.
/
sqrt
(
dh
),
0.0
,
nullptr
);
linear
(
qk_gemm
,
rearrange_q_buf
,
k_gemm
,
1.
/
sqrt
(
dh
),
0.0
,
nullptr
,
nullptr
);
// softmax
// softmax
auto
qk_softmax
=
qk_buf
->
view
Reshaped
({
nh
,
seq_len
,
total_len
});
auto
qk_softmax
=
qk_buf
->
view
({
nh
,
seq_len
,
total_len
});
causalSoftmax
(
qk_softmax
,
qk_softmax
);
causalSoftmax
(
qk_softmax
,
qk_softmax
);
auto
v_gemm
=
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
});
auto
v_gemm
=
kv_caches
[
req
]
->
v
[
idev
][
layer
]
->
slice
(
0
,
0
,
total_len
)
->
permute
({
1
,
0
,
2
});
linear
(
attn_val_buf
,
qk_gemm
,
v_gemm
,
1.0
,
0.0
,
nullptr
);
linear
(
attn_val_buf
,
qk_gemm
,
v_gemm
,
1.0
,
0.0
,
nullptr
,
nullptr
);
// rearrange attn val
// rearrange attn val
rearrange
(
o
,
attn_val_gemm
);
rearrange
(
o
,
attn_val_gemm
);
...
@@ -236,7 +233,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -236,7 +233,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
}
// o_proj
// o_proj
linear
(
logits_in
,
o_buf
,
rsrc
.
w_attn_out
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
);
// only rank 0 adds residual
linear
(
logits_in
,
o_buf
,
rsrc
.
w_attn_out
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
,
nullptr
);
// only rank 0 adds residual
// All_reduce if distributed
// All_reduce if distributed
if
(
rsrc
.
comm
!=
nullptr
)
{
if
(
rsrc
.
comm
!=
nullptr
)
{
...
@@ -247,9 +244,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -247,9 +244,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
}
// 2. FFN
// 2. FFN
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_ffn_norm
[
layer
],
meta
.
epsilon
);
rmsnorm
(
logits_out
,
logits_in
,
rsrc
.
w_ffn_norm
[
layer
],
meta
.
epsilon
);
linear
(
gate_up_buf
,
logits_out
,
rsrc
.
w_ffn_gate_up
[
layer
],
1.0
,
0.0
,
nullptr
);
linear
(
gate_up_buf
,
logits_out
,
rsrc
.
w_ffn_gate_up
[
layer
],
1.0
,
0.0
,
nullptr
,
nullptr
);
swiglu
(
gate_buf
,
up_buf
,
gate_buf
);
swiglu
(
gate_buf
,
up_buf
,
gate_buf
);
linear
(
logits_in
,
gate_buf
,
rsrc
.
w_ffn_down
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
);
// only rank 0 adds residual
linear
(
logits_in
,
gate_buf
,
rsrc
.
w_ffn_down
[
layer
],
1.0
,
0.0
,
idev
==
0
?
logits_in
:
nullptr
,
nullptr
);
// only rank 0 adds residual
// All_reduce if distributed
// All_reduce if distributed
if
(
rsrc
.
comm
!=
nullptr
)
{
if
(
rsrc
.
comm
!=
nullptr
)
{
...
@@ -270,15 +267,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
...
@@ -270,15 +267,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
rsrc
.
w_out_norm
,
rsrc
.
w_out_norm
,
meta
.
epsilon
);
meta
.
epsilon
);
}
}
linear
(
prob_buf
,
logits_out
->
slice
(
0
,
0
,
nreq
),
rsrc
.
w_out_embd
,
1.0
,
0.0
,
nullptr
);
linear
(
prob_buf
,
logits_out
->
slice
(
0
,
0
,
nreq
),
rsrc
.
w_out_embd
,
1.0
,
0.0
,
nullptr
,
nullptr
);
std
::
random_device
_rd
;
std
::
random_device
_rd
;
std
::
mt19937
gen
(
_rd
());
std
::
mt19937
gen
(
_rd
());
token_offset
=
0
;
token_offset
=
0
;
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
for
(
uint32_t
req
=
0
;
req
<
nreq
;
req
++
)
{
auto
seq_len
=
req_lens
[
req
];
auto
seq_len
=
req_lens
[
req
];
float
random_val
=
std
::
uniform_real_distribution
<
float
>
(
0
,
1
)(
gen
);
float
random_val
=
std
::
uniform_real_distribution
<
float
>
(
0
,
1
)(
gen
);
randomSample
(
result_buf
->
view
({},
{}
),
randomSample
(
result_buf
->
memShare
({},
result_buf
->
dtype
()
),
prob_buf
->
view
({
dvoc
},
{
1
}),
prob_buf
->
view
_as
({
dvoc
},
{
1
}),
random_val
,
topp
[
req
],
topk
[
req
],
temperature
[
req
]);
random_val
,
topp
[
req
],
topk
[
req
],
temperature
[
req
]);
token_offset
+=
seq_len
;
token_offset
+=
seq_len
;
}
}
...
...
src/tensor.hpp
View file @
6998a8f1
...
@@ -128,9 +128,8 @@ public:
...
@@ -128,9 +128,8 @@ public:
void
debug
()
const
;
void
debug
()
const
;
std
::
string
info
()
const
;
std
::
string
info
()
const
;
std
::
shared_ptr
<
Tensor
>
view
()
const
;
std
::
shared_ptr
<
Tensor
>
view
(
const
std
::
vector
<
size_t
>
&
new_shape
)
const
;
std
::
shared_ptr
<
Tensor
>
view
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
const
;
std
::
shared_ptr
<
Tensor
>
view_as
(
const
std
::
vector
<
size_t
>
&
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
&
new_strides
)
const
;
std
::
shared_ptr
<
Tensor
>
viewReshaped
(
const
std
::
vector
<
size_t
>
&
new_shape
)
const
;
~
Tensor
();
~
Tensor
();
};
};
...
...
src/tensor/tensor.cpp
View file @
6998a8f1
...
@@ -258,23 +258,7 @@ std::string Tensor::info() const {
...
@@ -258,23 +258,7 @@ std::string Tensor::info() const {
return
this
->
_desc
->
info
();
return
this
->
_desc
->
info
();
}
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
view
()
const
{
std
::
shared_ptr
<
Tensor
>
Tensor
::
view
(
const
std
::
vector
<
size_t
>
&
new_shape
)
const
{
std
::
shared_ptr
<
Tensor
>
tensor
=
std
::
make_shared
<
Tensor
>
();
tensor
->
_storage
=
this
->
_storage
;
tensor
->
_desc
=
TensorDesc
::
create
(
this
->
dtype
(),
this
->
shape
(),
this
->
strides
());
tensor
->
_offset
=
this
->
_offset
;
return
tensor
;
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
view
(
const
std
::
vector
<
size_t
>
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
new_strides
)
const
{
std
::
shared_ptr
<
Tensor
>
tensor
=
std
::
make_shared
<
Tensor
>
();
tensor
->
_storage
=
this
->
_storage
;
tensor
->
_desc
=
TensorDesc
::
create
(
this
->
dtype
(),
new_shape
,
new_strides
);
tensor
->
_offset
=
this
->
_offset
;
return
tensor
;
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
viewReshaped
(
const
std
::
vector
<
size_t
>
&
new_shape
)
const
{
// Calculate total elements in current and new shape
// Calculate total elements in current and new shape
size_t
current_elements
=
std
::
accumulate
(
size_t
current_elements
=
std
::
accumulate
(
_desc
->
shape
().
begin
(),
_desc
->
shape
().
end
(),
_desc
->
shape
().
begin
(),
_desc
->
shape
().
end
(),
...
@@ -340,6 +324,14 @@ std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> &new_shap
...
@@ -340,6 +324,14 @@ std::shared_ptr<Tensor> Tensor::viewReshaped(const std::vector<size_t> &new_shap
return
result
;
return
result
;
}
}
std
::
shared_ptr
<
Tensor
>
Tensor
::
view_as
(
const
std
::
vector
<
size_t
>
&
new_shape
,
const
std
::
vector
<
ptrdiff_t
>
&
new_strides
)
const
{
std
::
shared_ptr
<
Tensor
>
tensor
=
std
::
make_shared
<
Tensor
>
();
tensor
->
_storage
=
this
->
_storage
;
tensor
->
_desc
=
TensorDesc
::
create
(
this
->
dtype
(),
new_shape
,
new_strides
);
tensor
->
_offset
=
this
->
_offset
;
return
tensor
;
}
void
Tensor
::
debug
(
const
std
::
string
&
filename
)
const
{
void
Tensor
::
debug
(
const
std
::
string
&
filename
)
const
{
RUN_INFINI
(
infinirtDeviceSynchronize
());
RUN_INFINI
(
infinirtDeviceSynchronize
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment