Commit c50a5d09 authored by PanZezhong
Browse files

fix naming

parent 81fe2ba3
...@@ -24,32 +24,32 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta, ...@@ -24,32 +24,32 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
w_ffn_norm, w_ffn_gate_up, w_ffn_down; w_ffn_norm, w_ffn_gate_up, w_ffn_down;
for (size_t layer = 0; layer < meta->nlayer; layer++) { for (size_t layer = 0; layer < meta->nlayer; layer++) {
w_attn_norm.push_back( w_attn_norm.push_back(
get_attn_norm(meta, weights, layer)); getAttnNorm(meta, weights, layer));
w_attn_qkv.push_back( w_attn_qkv.push_back(
get_attn_qkv(meta, weights, layer, idev, ndev)); getAttnQKV(meta, weights, layer, idev, ndev));
if (weights->attn_qkv_b != nullptr) { if (weights->attn_qkv_b != nullptr) {
b_attn_qkv.push_back( b_attn_qkv.push_back(
get_attn_qkv_bias(meta, weights, layer, idev, ndev)); getAttnQKVBias(meta, weights, layer, idev, ndev));
} }
w_attn_out.push_back( w_attn_out.push_back(
get_attn_o(meta, weights, layer, idev, ndev)); getAttnO(meta, weights, layer, idev, ndev));
w_ffn_norm.push_back( w_ffn_norm.push_back(
get_ffn_norm(meta, weights, layer)); getFFNNorm(meta, weights, layer));
w_ffn_gate_up.push_back( w_ffn_gate_up.push_back(
get_ffn_gate_up(meta, weights, layer, idev, ndev)); getFFNGateUp(meta, weights, layer, idev, ndev));
w_ffn_down.push_back( w_ffn_down.push_back(
get_ffn_down(meta, weights, layer, idev, ndev)); getFFNDown(meta, weights, layer, idev, ndev));
} }
*rsrc = DeviceResource{device, *rsrc = DeviceResource{device,
dev_id, dev_id,
handle, handle,
get_in_embd(meta, weights), getInEmbd(meta, weights),
get_out_norm(meta, weights), getOutNorm(meta, weights),
get_out_embd(meta, weights), getOutEmbd(meta, weights),
get_sin_table(meta), getSinTable(meta),
get_cos_table(meta), getCosTable(meta),
w_attn_norm, w_attn_norm,
w_attn_qkv, w_attn_qkv,
b_attn_qkv, b_attn_qkv,
...@@ -136,7 +136,7 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc, ...@@ -136,7 +136,7 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_o, &temp_size)); RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_o, &temp_size));
workspace_size = std::max(workspace_size, temp_size); workspace_size = std::max(workspace_size, temp_size);
infiniopRoPEDescriptor_t desc_rope_q, desc_rope_k; infiniopRoPEDescriptor_t desc_rope_q, desc_rope_k;
qkv_buf->dim_split(1, {nh + nkvh * 2, dh}); // (ntok, nh + 2 * nkvh, dh) qkv_buf->dimSplit(1, {nh + nkvh * 2, dh}); // (ntok, nh + 2 * nkvh, dh)
auto qkv_buf_q = qkv_buf->slice(1, 0, nh); auto qkv_buf_q = qkv_buf->slice(1, 0, nh);
auto qkv_buf_k = qkv_buf->slice(1, nh, nkvh); auto qkv_buf_k = qkv_buf->slice(1, nh, nkvh);
RUN_INFINI(infiniopCreateRoPEDescriptor( RUN_INFINI(infiniopCreateRoPEDescriptor(
...@@ -154,7 +154,7 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc, ...@@ -154,7 +154,7 @@ void inferDeviceBatch(const JiugeMeta &meta, const DeviceResource &rsrc,
// attention inner // attention inner
auto desc_attns = std::vector<infiniopAttentionDescriptor_t>(nreq); auto desc_attns = std::vector<infiniopAttentionDescriptor_t>(nreq);
size_t token_offset = 0; size_t token_offset = 0;
o_buf->dim_split(1, {nh, dh}); o_buf->dimSplit(1, {nh, dh});
for (uint32_t req = 0; req < nreq; req++) { for (uint32_t req = 0; req < nreq; req++) {
auto past_len = req_pos[req]; auto past_len = req_pos[req];
auto seq_len = req_lens[req]; auto seq_len = req_lens[req];
......
...@@ -32,13 +32,13 @@ __C struct KVCache *duplicateKVCache(const JiugeModel *model, ...@@ -32,13 +32,13 @@ __C struct KVCache *duplicateKVCache(const JiugeModel *model,
for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) { for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) {
new_kv_cache->k[idev][layer] new_kv_cache->k[idev][layer]
->slice(1, 0, seq_len) ->slice(1, 0, seq_len)
->copy_from(kv_cache->k[idev][layer]->slice(1, 0, seq_len), ->copyFrom(kv_cache->k[idev][layer]->slice(1, 0, seq_len),
model->dev_resources[idev].handle); model->dev_resources[idev].handle);
new_kv_cache->v[idev][layer] new_kv_cache->v[idev][layer]
->slice(1, 0, seq_len) ->slice(1, 0, seq_len)
->copy_from(kv_cache->v[idev][layer]->slice(1, 0, seq_len), ->copyFrom(kv_cache->v[idev][layer]->slice(1, 0, seq_len),
model->dev_resources[idev].handle); model->dev_resources[idev].handle);
} }
} }
return new_kv_cache; return new_kv_cache;
......
...@@ -4,21 +4,21 @@ ...@@ -4,21 +4,21 @@
#include "jiuge_impl.hpp" #include "jiuge_impl.hpp"
#include <cmath> #include <cmath>
inline std::shared_ptr<Tensor> get_in_embd( inline std::shared_ptr<Tensor> getInEmbd(
JiugeMeta const *meta, JiugeMeta const *meta,
JiugeWeights const *w) { JiugeWeights const *w) {
auto shape = std::vector<size_t>({meta->dvoc, meta->d}); auto shape = std::vector<size_t>({meta->dvoc, meta->d});
return Tensor::weight((char *)w->input_embd, meta->dt_logits, shape); return Tensor::weight((char *)w->input_embd, meta->dt_logits, shape);
} }
inline std::shared_ptr<Tensor> get_out_norm( inline std::shared_ptr<Tensor> getOutNorm(
JiugeMeta const *meta, JiugeMeta const *meta,
JiugeWeights const *w) { JiugeWeights const *w) {
auto shape = std::vector<size_t>({meta->d}); auto shape = std::vector<size_t>({meta->d});
return Tensor::weight((char *)w->output_norm, meta->dt_norm, shape); return Tensor::weight((char *)w->output_norm, meta->dt_norm, shape);
} }
inline std::shared_ptr<Tensor> get_out_embd( inline std::shared_ptr<Tensor> getOutEmbd(
JiugeMeta const *meta, JiugeMeta const *meta,
JiugeWeights const *w) { JiugeWeights const *w) {
auto shape = std::vector<size_t>({meta->dvoc, meta->d}); auto shape = std::vector<size_t>({meta->dvoc, meta->d});
...@@ -26,7 +26,7 @@ inline std::shared_ptr<Tensor> get_out_embd( ...@@ -26,7 +26,7 @@ inline std::shared_ptr<Tensor> get_out_embd(
->permute({1, 0}); ->permute({1, 0});
} }
inline std::shared_ptr<Tensor> get_attn_norm( inline std::shared_ptr<Tensor> getAttnNorm(
JiugeMeta const *meta, JiugeMeta const *meta,
JiugeWeights const *w, JiugeWeights const *w,
size_t layer) { size_t layer) {
...@@ -34,7 +34,7 @@ inline std::shared_ptr<Tensor> get_attn_norm( ...@@ -34,7 +34,7 @@ inline std::shared_ptr<Tensor> get_attn_norm(
return Tensor::weight((char *)(w->attn_norm[layer]), meta->dt_norm, shape); return Tensor::weight((char *)(w->attn_norm[layer]), meta->dt_norm, shape);
} }
inline std::shared_ptr<Tensor> get_attn_qkv( inline std::shared_ptr<Tensor> getAttnQKV(
JiugeMeta const *meta, JiugeMeta const *meta,
JiugeWeights const *w, JiugeWeights const *w,
size_t layer, size_t idev, size_t ndev) { size_t layer, size_t idev, size_t ndev) {
...@@ -48,7 +48,7 @@ inline std::shared_ptr<Tensor> get_attn_qkv( ...@@ -48,7 +48,7 @@ inline std::shared_ptr<Tensor> get_attn_qkv(
->permute({1, 0}); ->permute({1, 0});
} }
inline std::shared_ptr<Tensor> get_attn_qkv_bias( inline std::shared_ptr<Tensor> getAttnQKVBias(
JiugeMeta const *meta, JiugeMeta const *meta,
JiugeWeights const *w, JiugeWeights const *w,
size_t layer, size_t idev, size_t ndev) { size_t layer, size_t idev, size_t ndev) {
...@@ -60,9 +60,9 @@ inline std::shared_ptr<Tensor> get_attn_qkv_bias( ...@@ -60,9 +60,9 @@ inline std::shared_ptr<Tensor> get_attn_qkv_bias(
return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, meta->dt_mat, shape); return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, meta->dt_mat, shape);
} }
inline std::shared_ptr<Tensor> get_attn_o(JiugeMeta const *meta, inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
JiugeWeights const *w, size_t layer, JiugeWeights const *w, size_t layer,
size_t idev, size_t ndev) { size_t idev, size_t ndev) {
auto nh = meta->nh; auto nh = meta->nh;
auto dh = meta->dh; auto dh = meta->dh;
auto d = meta->d; auto d = meta->d;
...@@ -72,7 +72,7 @@ inline std::shared_ptr<Tensor> get_attn_o(JiugeMeta const *meta, ...@@ -72,7 +72,7 @@ inline std::shared_ptr<Tensor> get_attn_o(JiugeMeta const *meta,
->permute({1, 0}); ->permute({1, 0});
} }
inline std::shared_ptr<Tensor> get_ffn_norm( inline std::shared_ptr<Tensor> getFFNNorm(
JiugeMeta const *meta, JiugeMeta const *meta,
JiugeWeights const *w, JiugeWeights const *w,
size_t layer) { size_t layer) {
...@@ -80,7 +80,7 @@ inline std::shared_ptr<Tensor> get_ffn_norm( ...@@ -80,7 +80,7 @@ inline std::shared_ptr<Tensor> get_ffn_norm(
return Tensor::weight((char *)(w->ffn_norm[layer]), meta->dt_norm, shape); return Tensor::weight((char *)(w->ffn_norm[layer]), meta->dt_norm, shape);
} }
inline std::shared_ptr<Tensor> get_ffn_gate_up( inline std::shared_ptr<Tensor> getFFNGateUp(
JiugeMeta const *meta, JiugeMeta const *meta,
JiugeWeights const *w, JiugeWeights const *w,
size_t layer, size_t idev, size_t ndev) { size_t layer, size_t idev, size_t ndev) {
...@@ -93,7 +93,7 @@ inline std::shared_ptr<Tensor> get_ffn_gate_up( ...@@ -93,7 +93,7 @@ inline std::shared_ptr<Tensor> get_ffn_gate_up(
->permute({1, 0}); ->permute({1, 0});
} }
inline std::shared_ptr<Tensor> get_ffn_down( inline std::shared_ptr<Tensor> getFFNDown(
JiugeMeta const *meta, JiugeMeta const *meta,
JiugeWeights const *w, JiugeWeights const *w,
size_t layer, size_t idev, size_t ndev) { size_t layer, size_t idev, size_t ndev) {
...@@ -105,7 +105,7 @@ inline std::shared_ptr<Tensor> get_ffn_down( ...@@ -105,7 +105,7 @@ inline std::shared_ptr<Tensor> get_ffn_down(
->permute({1, 0}); ->permute({1, 0});
} }
inline std::shared_ptr<Tensor> get_sin_table(JiugeMeta const *meta) { inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) {
float *table = (float *)std::malloc(meta->dctx * meta->dh * sizeof(float)); float *table = (float *)std::malloc(meta->dctx * meta->dh * sizeof(float));
auto half_dh = meta->dh / 2; auto half_dh = meta->dh / 2;
for (size_t i = 0; i < meta->dctx; i++) { for (size_t i = 0; i < meta->dctx; i++) {
...@@ -122,7 +122,7 @@ inline std::shared_ptr<Tensor> get_sin_table(JiugeMeta const *meta) { ...@@ -122,7 +122,7 @@ inline std::shared_ptr<Tensor> get_sin_table(JiugeMeta const *meta) {
return tensor; return tensor;
} }
inline std::shared_ptr<Tensor> get_cos_table(JiugeMeta const *meta) { inline std::shared_ptr<Tensor> getCosTable(JiugeMeta const *meta) {
float *table = (float *)std::malloc(meta->dctx * meta->dh * sizeof(float)); float *table = (float *)std::malloc(meta->dctx * meta->dh * sizeof(float));
auto half_dh = meta->dh / 2; auto half_dh = meta->dh / 2;
for (size_t i = 0; i < meta->dctx; i++) { for (size_t i = 0; i < meta->dctx; i++) {
......
...@@ -48,9 +48,9 @@ private: ...@@ -48,9 +48,9 @@ private:
std::shared_ptr<Storage> storage; std::shared_ptr<Storage> storage;
infiniopTensorDescriptor_t _desc; infiniopTensorDescriptor_t _desc;
void *data_impl(ptrdiff_t offset) const; void *dataImpl(ptrdiff_t offset) const;
std::shared_ptr<Tensor> std::shared_ptr<Tensor>
slice_impl(const std::vector<SliceParams> &slices) const; sliceImpl(const std::vector<SliceParams> &slices) const;
public: public:
static std::shared_ptr<Tensor> buffer(infiniDtype_t dtype, static std::shared_ptr<Tensor> buffer(infiniDtype_t dtype,
...@@ -65,23 +65,23 @@ public: ...@@ -65,23 +65,23 @@ public:
std::shared_ptr<Tensor> slice(const std::vector<SliceParams> &slices); std::shared_ptr<Tensor> slice(const std::vector<SliceParams> &slices);
std::shared_ptr<Tensor const> std::shared_ptr<Tensor const>
slice(const std::vector<SliceParams> &slices) const; slice(const std::vector<SliceParams> &slices) const;
std::shared_ptr<Tensor> dim_merge(size_t dim_start, size_t dim_end); std::shared_ptr<Tensor> dimMerge(size_t dim_start, size_t dim_end);
std::shared_ptr<Tensor> dim_split(size_t dim, std::shared_ptr<Tensor> dimSplit(size_t dim,
const std::vector<size_t> &dims); const std::vector<size_t> &dims);
std::shared_ptr<Tensor> permute(const std::vector<size_t> &order); std::shared_ptr<Tensor> permute(const std::vector<size_t> &order);
void *data(ptrdiff_t offset = 0); void *data(ptrdiff_t offset = 0);
void const *data(ptrdiff_t offset = 0) const; void const *data(ptrdiff_t offset = 0) const;
void copy_from(std::shared_ptr<Tensor const> src, infiniopHandle_t handle, void copyFrom(std::shared_ptr<Tensor const> src, infiniopHandle_t handle,
infinirtStream_t stream = nullptr); infinirtStream_t stream = nullptr);
const std::vector<size_t> &shape() const; const std::vector<size_t> &shape() const;
const std::vector<ptrdiff_t> &strides() const; const std::vector<ptrdiff_t> &strides() const;
size_t ndim() const; size_t ndim() const;
infiniDtype_t dtype() const; infiniDtype_t dtype() const;
std::shared_ptr<TensorDesc> desc() const; std::shared_ptr<TensorDesc> desc() const;
size_t byte_size() const; size_t byteSize() const;
ptrdiff_t data_offset() const; ptrdiff_t dataOffset() const;
infiniDevice_t device_type() const; infiniDevice_t deviceType() const;
int device_id() const; int deviceId() const;
bool is_contigous() const; bool is_contigous() const;
void debug(const std::string &filename) const; void debug(const std::string &filename) const;
......
...@@ -21,12 +21,12 @@ const std::vector<size_t> &Tensor::shape() const { return this->_shape; } ...@@ -21,12 +21,12 @@ const std::vector<size_t> &Tensor::shape() const { return this->_shape; }
const std::vector<ptrdiff_t> &Tensor::strides() const { return this->_strides; } const std::vector<ptrdiff_t> &Tensor::strides() const { return this->_strides; }
size_t Tensor::ndim() const { return this->_shape.size(); } size_t Tensor::ndim() const { return this->_shape.size(); }
infiniDtype_t Tensor::dtype() const { return this->_dtype; } infiniDtype_t Tensor::dtype() const { return this->_dtype; }
size_t Tensor::byte_size() const { return this->_size; } size_t Tensor::byteSize() const { return this->_size; }
infiniDevice_t Tensor::device_type() const { return this->storage->device_type; } infiniDevice_t Tensor::deviceType() const { return this->storage->device_type; }
int Tensor::device_id() const { return this->storage->device_id; } int Tensor::deviceId() const { return this->storage->device_id; }
Tensor::~Tensor() {} Tensor::~Tensor() {}
ptrdiff_t Tensor::data_offset() const { ptrdiff_t Tensor::dataOffset() const {
return (char *)(this->_data) - (char *)(this->storage->memory); return (char *)(this->_data) - (char *)(this->storage->memory);
} }
...@@ -90,22 +90,22 @@ std::shared_ptr<Tensor> Tensor::weight(void *data, infiniDtype_t dtype, ...@@ -90,22 +90,22 @@ std::shared_ptr<Tensor> Tensor::weight(void *data, infiniDtype_t dtype,
return tensor; return tensor;
} }
void *Tensor::data_impl(ptrdiff_t offset) const { void *Tensor::dataImpl(ptrdiff_t offset) const {
ASSERT(offset * dsize(this->dtype()) < this->_size); ASSERT(offset * dsize(this->dtype()) < this->_size);
return (char *)(this->_data) + offset * dsize(this->dtype()); return (char *)(this->_data) + offset * dsize(this->dtype());
} }
void *Tensor::data(ptrdiff_t offset) { void *Tensor::data(ptrdiff_t offset) {
return this->data_impl(offset); return this->dataImpl(offset);
} }
const void *Tensor::data(ptrdiff_t offset) const { const void *Tensor::data(ptrdiff_t offset) const {
return this->data_impl(offset); return this->dataImpl(offset);
} }
void Tensor::copy_from(std::shared_ptr<Tensor const> src, void Tensor::copyFrom(std::shared_ptr<Tensor const> src,
infiniopHandle_t handle, infinirtStream_t stream) { infiniopHandle_t handle, infinirtStream_t stream) {
ASSERT_EQ(this->shape(), src->shape()); ASSERT_EQ(this->shape(), src->shape());
ASSERT_EQ(this->dtype(), src->dtype()); ASSERT_EQ(this->dtype(), src->dtype());
infiniopRearrangeDescriptor_t desc; infiniopRearrangeDescriptor_t desc;
...@@ -172,11 +172,11 @@ void Tensor::debug(const std::string &filename) const { ...@@ -172,11 +172,11 @@ void Tensor::debug(const std::string &filename) const {
std::cout << s << " "; std::cout << s << " ";
} }
std::cout << "] dtype=" << this->dtype() std::cout << "] dtype=" << this->dtype()
<< " device=" << this->device_type() << " device=" << this->deviceType()
<< " device_id=" << this->device_id() << std::endl; << " device_id=" << this->deviceId() << std::endl;
auto dtype = this->dtype(); auto dtype = this->dtype();
void const *cpu_data; void const *cpu_data;
if (this->device_type() != INFINI_DEVICE_CPU) { if (this->deviceType() != INFINI_DEVICE_CPU) {
void *cpu_memory = std::malloc(this->storage->size); void *cpu_memory = std::malloc(this->storage->size);
RUN_INFINI(infinirtMemcpy(cpu_memory, this->storage->memory, RUN_INFINI(infinirtMemcpy(cpu_memory, this->storage->memory,
this->storage->size, INFINIRT_MEMCPY_D2H)); this->storage->size, INFINIRT_MEMCPY_D2H));
...@@ -199,27 +199,27 @@ void Tensor::debug(const std::string &filename) const { ...@@ -199,27 +199,27 @@ void Tensor::debug(const std::string &filename) const {
switch (dtype) { switch (dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
print_data((uint16_t const *)((char const *)cpu_data + data_offset()), print_data((uint16_t const *)((char const *)cpu_data + dataOffset()),
this->shape(), this->strides(), 0); this->shape(), this->strides(), 0);
break; break;
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
print_data((float const *)((char const *)cpu_data + data_offset()), print_data((float const *)((char const *)cpu_data + dataOffset()),
this->shape(), this->strides(), 0); this->shape(), this->strides(), 0);
break; break;
case INFINI_DTYPE_U64: case INFINI_DTYPE_U64:
print_data((uint64_t const *)((char const *)cpu_data + data_offset()), print_data((uint64_t const *)((char const *)cpu_data + dataOffset()),
this->shape(), this->strides(), 0); this->shape(), this->strides(), 0);
break; break;
case INFINI_DTYPE_I64: case INFINI_DTYPE_I64:
print_data((int64_t const *)((char const *)cpu_data + data_offset()), print_data((int64_t const *)((char const *)cpu_data + dataOffset()),
this->shape(), this->strides(), 0); this->shape(), this->strides(), 0);
break; break;
case INFINI_DTYPE_U32: case INFINI_DTYPE_U32:
print_data((uint32_t const *)((char const *)cpu_data + data_offset()), print_data((uint32_t const *)((char const *)cpu_data + dataOffset()),
this->shape(), this->strides(), 0); this->shape(), this->strides(), 0);
break; break;
case INFINI_DTYPE_I32: case INFINI_DTYPE_I32:
print_data((int32_t const *)((char const *)cpu_data + data_offset()), print_data((int32_t const *)((char const *)cpu_data + dataOffset()),
this->shape(), this->strides(), 0); this->shape(), this->strides(), 0);
break; break;
default: default:
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <numeric> #include <numeric>
#include <vector> #include <vector>
std::shared_ptr<Tensor> Tensor::slice_impl(const std::vector<SliceParams> &slices) const { std::shared_ptr<Tensor> Tensor::sliceImpl(const std::vector<SliceParams> &slices) const {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>(); std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
auto new_shape = std::vector<size_t>(this->_shape); auto new_shape = std::vector<size_t>(this->_shape);
...@@ -32,22 +32,22 @@ std::shared_ptr<Tensor> Tensor::slice_impl(const std::vector<SliceParams> &slice ...@@ -32,22 +32,22 @@ std::shared_ptr<Tensor> Tensor::slice_impl(const std::vector<SliceParams> &slice
} }
std::shared_ptr<Tensor> Tensor::slice(size_t dim, size_t start, size_t len) { std::shared_ptr<Tensor> Tensor::slice(size_t dim, size_t start, size_t len) {
return this->slice_impl({{dim, start, len}}); return this->sliceImpl({{dim, start, len}});
} }
std::shared_ptr<Tensor const> Tensor::slice(size_t dim, size_t start, size_t len) const { std::shared_ptr<Tensor const> Tensor::slice(size_t dim, size_t start, size_t len) const {
return this->slice_impl({{dim, start, len}}); return this->sliceImpl({{dim, start, len}});
} }
std::shared_ptr<Tensor> Tensor::slice(const std::vector<SliceParams> &slices) { std::shared_ptr<Tensor> Tensor::slice(const std::vector<SliceParams> &slices) {
return this->slice_impl(slices); return this->sliceImpl(slices);
} }
std::shared_ptr<Tensor const> Tensor::slice(const std::vector<SliceParams> &slices) const { std::shared_ptr<Tensor const> Tensor::slice(const std::vector<SliceParams> &slices) const {
return this->slice_impl(slices); return this->sliceImpl(slices);
} }
std::shared_ptr<Tensor> Tensor::dim_merge(size_t dim_start, size_t dim_end) { std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
ASSERT(dim_start <= dim_end && dim_end < this->_shape.size()); ASSERT(dim_start <= dim_end && dim_end < this->_shape.size());
if (dim_start == dim_end) { if (dim_start == dim_end) {
return shared_from_this(); return shared_from_this();
...@@ -77,7 +77,7 @@ std::shared_ptr<Tensor> Tensor::dim_merge(size_t dim_start, size_t dim_end) { ...@@ -77,7 +77,7 @@ std::shared_ptr<Tensor> Tensor::dim_merge(size_t dim_start, size_t dim_end) {
return shared_from_this(); return shared_from_this();
} }
std::shared_ptr<Tensor> Tensor::dim_split(size_t dim, const std::vector<size_t> &dims) { std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) {
ASSERT_EQ(this->_shape[dim], std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>())); ASSERT_EQ(this->_shape[dim], std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>()));
auto new_shape = std::vector<size_t>(); auto new_shape = std::vector<size_t>();
auto new_strides = std::vector<ptrdiff_t>(); auto new_strides = std::vector<ptrdiff_t>();
......
...@@ -6,16 +6,16 @@ ...@@ -6,16 +6,16 @@
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
inline void assert_true(int expr, const char *msg, const char *file, int line) { inline void assertTrue(int expr, const char *msg, const char *file, int line) {
if (!expr) { if (!expr) {
fprintf(stderr, "\033[31mAssertion failed:\033[0m %s at file %s, line %d\n", msg, file, line); fprintf(stderr, "\033[31mAssertion failed:\033[0m %s at file %s, line %d\n", msg, file, line);
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
} }
#define ASSERT(expr) assert_true((expr), #expr " is false", __FILE__, __LINE__) #define ASSERT(expr) assertTrue((expr), #expr " is false", __FILE__, __LINE__)
#define ASSERT_EQ(a, b) assert_true((a) == (b), #a " != " #b, __FILE__, __LINE__) #define ASSERT_EQ(a, b) assertTrue((a) == (b), #a " != " #b, __FILE__, __LINE__)
#define ASSERT_VALID_PTR(a) assert_true((a) != nullptr, #a " is nullptr", __FILE__, __LINE__) #define ASSERT_VALID_PTR(a) assertTrue((a) != nullptr, #a " is nullptr", __FILE__, __LINE__)
#define PANIC(EXPR) \ #define PANIC(EXPR) \
printf("Error at %s:%d - %s\n", __FILE__, __LINE__, #EXPR); \ printf("Error at %s:%d - %s\n", __FILE__, __LINE__, #EXPR); \
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment