Commit 1e30fac0 authored by PanZezhong's avatar PanZezhong
Browse files

refactor tensor desc

parent 7a087fdc
......@@ -169,8 +169,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
// attn & mlp rmsnorm
infiniopRMSNormDescriptor_t desc_norm;
RUN_INFINI(infiniopCreateRMSNormDescriptor(
rsrc.handle, &desc_norm, logits_in->desc()->get(),
logits_out->desc()->get(), rsrc.w_attn_norm[0]->desc()->get(),
rsrc.handle, &desc_norm, logits_in->desc(),
logits_out->desc(), rsrc.w_attn_norm[0]->desc(),
meta.epsilon));
RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc_norm, &workspace_size));
workspace_size = std::max(workspace_size, temp_size);
......@@ -179,15 +179,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
infiniopRearrangeDescriptor_t desc_qkv_bias;
if (has_qkv_bias) {
RUN_INFINI(infiniopCreateRearrangeDescriptor(
rsrc.handle, &desc_qkv_bias, qkv_buf->desc()->get(),
TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh}, {0, 1})->get()));
rsrc.handle, &desc_qkv_bias, qkv_buf->desc(),
TensorDesc::create(dt_logits, {ntok, (nh + nkvh * 2) * dh}, {0, 1})->desc()));
}
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_attn_qkv, qkv_buf->desc()->get(),
logits_in->desc()->get(), rsrc.w_attn_qkv[0]->desc()->get()));
rsrc.handle, &desc_attn_qkv, qkv_buf->desc(),
logits_in->desc(), rsrc.w_attn_qkv[0]->desc()));
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_attn_o, logits_in->desc()->get(),
o_buf->desc()->get(), rsrc.w_attn_out[0]->desc()->get()));
rsrc.handle, &desc_attn_o, logits_in->desc(),
o_buf->desc(), rsrc.w_attn_out[0]->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_qkv, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_o, &temp_size));
......@@ -197,15 +197,15 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto qkv_buf_q = qkv_buf->slice(1, 0, nh);
auto qkv_buf_k = qkv_buf->slice(1, nh, nkvh);
RUN_INFINI(infiniopCreateRoPEDescriptor(
rsrc.handle, &desc_rope_q, qkv_buf_q->desc()->get(), qkv_buf_q->desc()->get(),
pos_ids_buf->desc()->get(), rsrc.sin_table->desc()->get(),
rsrc.cos_table->desc()->get()));
rsrc.handle, &desc_rope_q, qkv_buf_q->desc(), qkv_buf_q->desc(),
pos_ids_buf->desc(), rsrc.sin_table->desc(),
rsrc.cos_table->desc()));
RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc_rope_q, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
RUN_INFINI(infiniopCreateRoPEDescriptor(
rsrc.handle, &desc_rope_k, qkv_buf_k->desc()->get(), qkv_buf_k->desc()->get(),
pos_ids_buf->desc()->get(), rsrc.sin_table->desc()->get(),
rsrc.cos_table->desc()->get()));
rsrc.handle, &desc_rope_k, qkv_buf_k->desc(), qkv_buf_k->desc(),
pos_ids_buf->desc(), rsrc.sin_table->desc(),
rsrc.cos_table->desc()));
RUN_INFINI(infiniopGetRoPEWorkspaceSize(desc_rope_k, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
// attention inner
......@@ -233,38 +233,38 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto cache_kv = kv_caches[req]->k[idev][0]->slice(0, past_len, seq_len);
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_kv_rearranges[req],
cache_kv->desc()->get(), k->desc()->get()));
cache_kv->desc(), k->desc()));
// [nkvh, ngroup, seq_len, dh]
q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
auto q_t = TensorDesc::create(dt_logits, {nkvh, ngroup, seq_len, dh});
// [seq_len, nkvh, ngroup, dh] -> [nkvh, ngroup, seq_len, dh]
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_q_rearranges[req],
q_t->get(), q->desc()->get()));
q_t->desc(), q->desc()));
// [nkvh, ngroup, seq_len, dh] -> [seq_len, nkvh, ngroup, dh]
auto attn_v_t = q_t;
auto attn_v = TensorDesc::createWithOrder(dt_logits, {nkvh, ngroup, seq_len, dh}, {1, 2, 0, 3});
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_attn_v_rearranges[req],
attn_v->get(), attn_v_t->get()));
attn_v->desc(), attn_v_t->desc()));
q_t = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, dh});
auto qk = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, total_len});
max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len));
max_seq_len = std::max(max_seq_len, size_t(seq_len));
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_qk_gemms[req], qk->get(), q_t->get(), full_kv->desc()->get()));
rsrc.handle, &desc_qk_gemms[req], qk->desc(), q_t->desc(), full_kv->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_qk_gemms[req], &temp_size));
workspace_size = std::max(workspace_size, temp_size);
// [nkvh, total_len, dh]
auto full_v = kv_caches[req]->v[idev][0]->slice(0, 0, total_len)->permute({1, 0, 2});
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_attn_v_gemms[req], q_t->get(), qk->get(), full_v->desc()->get()));
rsrc.handle, &desc_attn_v_gemms[req], q_t->desc(), qk->desc(), full_v->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_v_gemms[req], &temp_size));
workspace_size = std::max(workspace_size, temp_size);
qk = TensorDesc::create(dt_logits, {nkvh * ngroup, seq_len, total_len});
RUN_INFINI(infiniopCreateCausalSoftmaxDescriptor(
rsrc.handle, &desc_qk_softmaxs[req], qk->get(), qk->get()));
rsrc.handle, &desc_qk_softmaxs[req], qk->desc(), qk->desc()));
RUN_INFINI(infiniopGetCausalSoftmaxWorkspaceSize(desc_qk_softmaxs[req], &temp_size));
workspace_size = std::max(workspace_size, temp_size);
......@@ -278,47 +278,47 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
infiniopGemmDescriptor_t desc_ffn_gate_up, desc_ffn_down;
infiniopSwiGLUDescriptor_t desc_swiglu;
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_ffn_gate_up, gate_up_buf->desc()->get(),
logits_out->desc()->get(), rsrc.w_ffn_gate_up[0]->desc()->get()));
rsrc.handle, &desc_ffn_gate_up, gate_up_buf->desc(),
logits_out->desc(), rsrc.w_ffn_gate_up[0]->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_ffn_gate_up, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
auto gate_buf = gate_up_buf->slice(1, 0, di);
auto up_buf = gate_up_buf->slice(1, di, di);
RUN_INFINI(infiniopCreateSwiGLUDescriptor(
rsrc.handle, &desc_swiglu, gate_buf->desc()->get(), up_buf->desc()->get(), gate_buf->desc()->get()));
rsrc.handle, &desc_swiglu, gate_buf->desc(), up_buf->desc(), gate_buf->desc()));
RUN_INFINI(infiniopGetSwiGLUWorkspaceSize(desc_swiglu, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_ffn_down, logits_in->desc()->get(),
gate_buf->desc()->get(), rsrc.w_ffn_down[0]->desc()->get()));
rsrc.handle, &desc_ffn_down, logits_in->desc(),
gate_buf->desc(), rsrc.w_ffn_down[0]->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_ffn_down, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
// Output and sample
infiniopRMSNormDescriptor_t desc_norm_out;
RUN_INFINI(infiniopCreateRMSNormDescriptor(
rsrc.handle, &desc_norm_out, logits_out->slice(0, 0, 1)->desc()->get(),
logits_out->slice(0, 0, 1)->desc()->get(),
rsrc.w_out_norm->desc()->get(), meta.epsilon));
rsrc.handle, &desc_norm_out, logits_out->slice(0, 0, 1)->desc(),
logits_out->slice(0, 0, 1)->desc(),
rsrc.w_out_norm->desc(), meta.epsilon));
RUN_INFINI(infiniopGetRMSNormWorkspaceSize(desc_norm_out, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
infiniopGemmDescriptor_t desc_out_embd;
RUN_INFINI(infiniopCreateGemmDescriptor(
rsrc.handle, &desc_out_embd, prob_buf->desc()->get(),
logits_out->slice(0, 0, nreq)->desc()->get(),
rsrc.w_out_embd->desc()->get()));
rsrc.handle, &desc_out_embd, prob_buf->desc(),
logits_out->slice(0, 0, nreq)->desc(),
rsrc.w_out_embd->desc()));
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_out_embd, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
infiniopRandomSampleDescriptor_t desc_sample;
RUN_INFINI(infiniopCreateRandomSampleDescriptor(
rsrc.handle, &desc_sample,
TensorDesc::create(INFINI_DTYPE_I64, {}, {})->get(),
TensorDesc::create(dt_logits, {dvoc}, {1})->get()));
TensorDesc::create(INFINI_DTYPE_I64, {}, {})->desc(),
TensorDesc::create(dt_logits, {dvoc}, {1})->desc()));
RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc_sample, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
// Allocate workspace
std::shared_ptr<Storage> workspace_storage = Storage::createFromPool(workspace_size, rsrc.memory_pool);
void *workspace = workspace_storage->memory;
void *workspace = workspace_storage->memory();
// Compute
for (uint32_t layer = 0; layer < nlayer; layer++) {
......
......@@ -11,19 +11,23 @@
// RAII wrapper around a raw device/host allocation. Instances are always
// heap-allocated via the static factories (the constructor is private) and
// shared via std::shared_ptr; the destructor returns the memory to its pool
// or frees it through the infinirt API.
// NOTE(review): this span contained both the pre- and post-refactor member
// lists from the diff; resolved to the refactored (private-member) version.
class Storage {
private:
    Storage() = default;
    void *_memory;
    size_t _size;
    infiniDevice_t _device_type;
    int _device_id;
    // Non-null when the memory came from a pool (see createFromPool).
    std::shared_ptr<MemoryPool> _memory_pool;

public:
    static std::shared_ptr<Storage> create(size_t size);
    static std::shared_ptr<Storage> createAsync(size_t size, infinirtStream_t stream = nullptr);
    static std::shared_ptr<Storage> createFromPool(size_t size, std::shared_ptr<MemoryPool> pool = nullptr);
    static std::shared_ptr<Storage> createHost(size_t size);
    ~Storage();
    void *memory() const { return _memory; }
    size_t size() const { return _size; }
    infiniDevice_t deviceType() const { return _device_type; }
    int deviceId() const { return _device_id; }
};
struct SliceParams {
......@@ -43,9 +47,17 @@ std::vector<ptrdiff_t> __strides(Args... args) {
}
class TensorDesc {
private:
infiniDtype_t _dtype;
std::vector<size_t> _shape;
std::vector<ptrdiff_t> _strides;
infiniopTensorDescriptor_t _desc;
TensorDesc(infiniDtype_t dtype, const std::vector<size_t> &shape,
const std::vector<ptrdiff_t> &strides) : _dtype(dtype), _shape(shape), _strides(strides), _desc(nullptr) {}
void resetDesc();
public:
~TensorDesc();
static std::shared_ptr<TensorDesc>
create(infiniDtype_t dtype, const std::vector<size_t> &shape,
const std::vector<ptrdiff_t> &strides);
......@@ -54,19 +66,26 @@ public:
static std::shared_ptr<TensorDesc>
createWithOrder(infiniDtype_t dtype, const std::vector<size_t> &shape,
const std::vector<size_t> &order);
infiniopTensorDescriptor_t get() const { return _desc; };
~TensorDesc();
infiniDtype_t dtype() const { return _dtype; }
const std::vector<size_t> &shape() const { return _shape; }
const std::vector<ptrdiff_t> &strides() const { return _strides; }
size_t ndim() const { return _shape.size(); }
infiniopTensorDescriptor_t desc() const;
bool isContigous() const;
std::string info() const;
void dimMerge(size_t dim_start, size_t dim_end);
void dimSplit(size_t dim, const std::vector<size_t> &dims);
void permute(const std::vector<size_t> &order);
};
class Tensor : public std::enable_shared_from_this<Tensor> {
private:
infiniDtype_t _dtype;
std::vector<size_t> _shape;
std::vector<ptrdiff_t> _strides;
void *_data;
ptrdiff_t _offset;
std::shared_ptr<Storage> _storage;
infiniopTensorDescriptor_t _desc;
std::shared_ptr<TensorDesc> _desc;
ptrdiff_t _offset;
void *dataImpl(ptrdiff_t offset) const;
std::shared_ptr<Tensor>
......@@ -99,11 +118,11 @@ public:
const std::vector<ptrdiff_t> &strides() const;
size_t ndim() const;
infiniDtype_t dtype() const;
std::shared_ptr<TensorDesc> desc() const;
bool isContigous() const;
infiniopTensorDescriptor_t desc() const;
ptrdiff_t dataOffset() const;
infiniDevice_t deviceType() const;
int deviceId() const;
bool is_contigous() const;
void debug(const std::string &filename) const;
void debug() const;
......
......@@ -3,51 +3,51 @@
// Allocate `size` bytes of device memory synchronously and record the
// device it was allocated on.
// NOTE(review): the diff showed both old (public-member) and new
// (private-member) assignments; resolved to the refactored version.
std::shared_ptr<Storage> Storage::create(size_t size) {
    auto storage = std::shared_ptr<Storage>(new Storage());
    RUN_INFINI(infinirtMalloc(&storage->_memory, size));
    storage->_size = size;
    RUN_INFINI(infinirtGetDevice(&storage->_device_type, &storage->_device_id));
    return storage;
}
// Allocate `size` bytes of device memory asynchronously on `stream`
// (default stream when nullptr) and record the current device.
std::shared_ptr<Storage> Storage::createAsync(size_t size, infinirtStream_t stream) {
    auto storage = std::shared_ptr<Storage>(new Storage());
    RUN_INFINI(infinirtMallocAsync(&storage->_memory, size, stream));
    storage->_size = size;
    RUN_INFINI(infinirtGetDevice(&storage->_device_type, &storage->_device_id));
    return storage;
}
// Obtain `size` bytes either from `pool` (when given) or directly via
// infinirtMalloc. The pool is remembered so the destructor releases the
// block back to it instead of freeing.
std::shared_ptr<Storage> Storage::createFromPool(size_t size, std::shared_ptr<MemoryPool> pool) {
    auto storage = std::shared_ptr<Storage>(new Storage());
    storage->_memory_pool = pool;
    if (pool) {
        storage->_memory = pool->alloc(size);
    } else {
        RUN_INFINI(infinirtMalloc(&storage->_memory, size));
    }
    storage->_size = size;
    RUN_INFINI(infinirtGetDevice(&storage->_device_type, &storage->_device_id));
    return storage;
}
// Allocate `size` bytes of host (CPU) memory. Device is pinned to CPU/0 so
// the destructor takes the infinirtFreeHost path.
std::shared_ptr<Storage> Storage::createHost(size_t size) {
    auto storage = std::shared_ptr<Storage>(new Storage());
    RUN_INFINI(infinirtMallocHost(&storage->_memory, size));
    storage->_size = size;
    storage->_device_type = INFINI_DEVICE_CPU;
    storage->_device_id = 0;
    storage->_memory_pool = nullptr; // No pool for host memory
    return storage;
}
// Return pooled memory to its pool; otherwise free host memory via
// infinirtFreeHost (CPU) or device memory via infinirtFree.
Storage::~Storage() {
    if (_memory_pool) {
        _memory_pool->release(_memory);
    } else {
        if (_device_type == INFINI_DEVICE_CPU) {
            RUN_INFINI(infinirtFreeHost(_memory));
        } else {
            RUN_INFINI(infinirtFree(_memory));
        }
    }
}
......@@ -9,10 +9,7 @@
// Build a TensorDesc holding only metadata (dtype/shape/strides). The
// infiniop descriptor itself is created lazily on the first desc() call.
// NOTE(review): the diff span contained both the old eager-creation body and
// the new lazy one; resolved to the refactored (lazy) version.
std::shared_ptr<TensorDesc>
TensorDesc::create(infiniDtype_t dtype, const std::vector<size_t> &shape,
                   const std::vector<ptrdiff_t> &strides) {
    return std::shared_ptr<TensorDesc>(new TensorDesc(dtype, shape, strides));
}
std::shared_ptr<TensorDesc>
......@@ -48,31 +45,74 @@ TensorDesc::createWithOrder(infiniDtype_t dtype, const std::vector<size_t> &shap
return create(dtype, shape, strides);
}
// Return the infiniop descriptor, creating it on first use from the stored
// dtype/shape/strides (lazy caching; resetDesc() invalidates the cache).
infiniopTensorDescriptor_t TensorDesc::desc() const {
    if (_desc == nullptr) {
        // _desc is mutated inside a const method; a named const_cast makes
        // that explicit (cleaner fix: declare _desc `mutable` in the class).
        RUN_INFINI(infiniopCreateTensorDescriptor(
            const_cast<infiniopTensorDescriptor_t *>(&_desc), _shape.size(),
            _shape.data(), _strides.data(), _dtype));
    }
    return _desc;
}
// Invalidate the cached infiniop descriptor so the next desc() call
// rebuilds it from the current shape/strides. No-op when nothing is cached.
void TensorDesc::resetDesc() {
    if (this->_desc == nullptr) {
        return;
    }
    infiniopDestroyTensorDescriptor(this->_desc);
    this->_desc = nullptr;
}
// True when the strides describe a dense row-major layout for the shape.
bool TensorDesc::isContigous() const {
    const auto ndim = this->ndim();
    if (ndim == 0) {
        // Scalar: trivially contiguous. The original indexed
        // strides[ndim - 1] here, which wraps size_t and writes out of
        // bounds for a 0-d descriptor.
        return true;
    }
    const auto &shape = this->shape(); // bind by reference; avoid copying the vector
    auto expected = std::vector<ptrdiff_t>(ndim);
    expected[ndim - 1] = 1;
    for (int i = static_cast<int>(ndim) - 2; i >= 0; i--) {
        expected[i] = expected[i + 1] * shape[i + 1];
    }
    ASSERT_EQ(expected.size(), this->_strides.size());
    return std::equal(expected.begin(), expected.end(), this->_strides.begin());
}
// Human-readable one-line summary: shape, strides and dtype.
std::string TensorDesc::info() const {
    std::stringstream out;
    // Shared helper: append a space-separated list of values.
    auto dump = [&out](const auto &values) {
        for (const auto &v : values) {
            out << v << " ";
        }
    };
    out << "Tensor: "
        << "shape[ ";
    dump(this->shape());
    out << "] strides[ ";
    dump(this->strides());
    out << "] dtype=" << this->dtype();
    return out.str();
}
// Destroy the lazily created infiniop descriptor, if one was ever built.
// NOTE(review): diff span held both the old direct destroy call and the new
// resetDesc() call; resolved to the refactored version (resetDesc() is
// null-safe, the old unconditional destroy was not).
TensorDesc::~TensorDesc() {
    this->resetDesc();
}
// Metadata accessors now delegate to the shared TensorDesc / Storage
// objects instead of duplicated Tensor members.
// NOTE(review): diff span held both the old direct-member accessors and the
// new delegating ones; resolved to the refactored versions.
const std::vector<size_t> &Tensor::shape() const { return this->_desc->shape(); }
const std::vector<ptrdiff_t> &Tensor::strides() const { return this->_desc->strides(); }
size_t Tensor::ndim() const { return this->_desc->ndim(); }
infiniDtype_t Tensor::dtype() const { return this->_desc->dtype(); }
infiniDevice_t Tensor::deviceType() const { return this->_storage->deviceType(); }
int Tensor::deviceId() const { return this->_storage->deviceId(); }
// Storage and descriptor are shared_ptr-managed; nothing to free explicitly.
Tensor::~Tensor() {}
// Byte offset of this view into its backing Storage. (sliceImpl scales the
// element offset by dsize(dtype) before storing it here.)
ptrdiff_t Tensor::dataOffset() const {
    return _offset;
}
// Raw infiniop descriptor for this tensor, served from the cached
// TensorDesc (built lazily inside TensorDesc::desc()).
// NOTE(review): diff span held both the old per-call TensorDesc::create
// version and the new cached delegation; resolved to the refactored one.
infiniopTensorDescriptor_t Tensor::desc() const { return _desc->desc(); }
std::shared_ptr<Tensor> Tensor::buffer(infiniDtype_t dtype,
const std::vector<size_t> &shape,
std::shared_ptr<MemoryPool> pool) {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_dtype = dtype;
auto ndim = shape.size();
tensor->_shape = std::vector<size_t>(shape);
size_t size = std::accumulate(shape.begin(), shape.end(), dsize(dtype), std::multiplies<size_t>());
auto strides = std::vector<ptrdiff_t>(ndim);
......@@ -82,11 +122,8 @@ std::shared_ptr<Tensor> Tensor::buffer(infiniDtype_t dtype,
strides[i] = strides[i + 1] * shape[i + 1];
}
}
tensor->_strides = strides;
tensor->_storage = Storage::createFromPool(size, pool);
tensor->_data = tensor->_storage->memory;
infiniopCreateTensorDescriptor(&tensor->_desc, ndim, tensor->_shape.data(),
strides.data(), dtype);
tensor->_desc = TensorDesc::create(dtype, shape, strides);
tensor->_offset = 0;
return tensor;
}
......@@ -94,9 +131,7 @@ std::shared_ptr<Tensor> Tensor::buffer(infiniDtype_t dtype,
std::shared_ptr<Tensor> Tensor::weight(void *data, infiniDtype_t dtype,
const std::vector<size_t> &shape) {
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_dtype = dtype;
auto ndim = shape.size();
tensor->_shape = std::vector<size_t>(shape);
size_t size = std::accumulate(shape.begin(), shape.end(), dsize(dtype), std::multiplies<size_t>());
auto strides = std::vector<ptrdiff_t>(ndim);
if (ndim > 0) {
......@@ -105,24 +140,22 @@ std::shared_ptr<Tensor> Tensor::weight(void *data, infiniDtype_t dtype,
strides[i] = strides[i + 1] * shape[i + 1];
}
}
tensor->_strides = strides;
tensor->_storage = Storage::create(size);
RUN_INFINI(infinirtMemcpy(tensor->_storage->memory,
tensor->_desc = TensorDesc::create(dtype, shape, strides);
RUN_INFINI(infinirtMemcpy(tensor->_storage->memory(),
data, size, INFINIRT_MEMCPY_H2D));
tensor->_data = tensor->_storage->memory;
infiniopCreateTensorDescriptor(&tensor->_desc, ndim, tensor->_shape.data(),
strides.data(), dtype);
tensor->_offset = 0;
return tensor;
}
std::shared_ptr<Tensor> Tensor::memShare(const std::vector<size_t> &shape, infiniDtype_t dtype) const {
std::shared_ptr<Tensor> Tensor::memShare(const std::vector<size_t> &shape, infiniDtype_t dtype_) const {
auto dtype = dtype_ == INFINI_DTYPE_INVALID ? this->dtype() : dtype_;
size_t size = std::accumulate(shape.begin(), shape.end(), dsize(dtype), std::multiplies<size_t>());
ASSERT(size <= this->_storage->size);
ASSERT(size <= this->_storage->size());
std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
tensor->_dtype = dtype == INFINI_DTYPE_INVALID ? this->_dtype : dtype;
tensor->_shape = std::vector<size_t>(shape);
auto ndim = shape.size();
auto strides = std::vector<ptrdiff_t>(ndim);
if (ndim > 0) {
......@@ -131,16 +164,14 @@ std::shared_ptr<Tensor> Tensor::memShare(const std::vector<size_t> &shape, infin
strides[i] = strides[i + 1] * shape[i + 1];
}
}
tensor->_strides = strides;
tensor->_storage = this->_storage;
infiniopCreateTensorDescriptor(&tensor->_desc, ndim, tensor->_shape.data(),
tensor->_strides.data(), tensor->_dtype);
tensor->_offset = 0;
tensor->_desc = TensorDesc::create(dtype, shape, strides);
return tensor;
}
void *Tensor::dataImpl(ptrdiff_t offset) const {
return (char *)(this->_data) + offset * dsize(this->dtype());
return (char *)(this->_storage->memory()) + this->_offset + offset * dsize(this->dtype());
}
void *Tensor::data(ptrdiff_t offset) {
......@@ -157,22 +188,14 @@ void Tensor::copyFrom(std::shared_ptr<Tensor const> src,
ASSERT_EQ(this->dtype(), src->dtype());
infiniopRearrangeDescriptor_t desc;
RUN_INFINI(infiniopCreateRearrangeDescriptor(
handle, &desc, this->desc()->get(), src->desc()->get()));
handle, &desc, this->desc(), src->desc()));
RUN_INFINI(infiniopRearrange(desc, this->data(), src->data(),
stream));
RUN_INFINI(infiniopDestroyRearrangeDescriptor(desc));
}
bool Tensor::is_contigous() const {
auto ndim = this->ndim();
auto shape = this->shape();
auto strides = std::vector<ptrdiff_t>(ndim);
strides[ndim - 1] = 1;
for (int i = ndim - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * shape[i + 1];
}
ASSERT_EQ(strides.size(), this->_strides.size());
return std::equal(strides.begin(), strides.end(), this->_strides.begin());
bool Tensor::isContigous() const {
return this->_desc->isContigous();
}
template <typename T>
......@@ -209,34 +232,25 @@ std::string Tensor::info() const {
std::stringstream ss;
ss << "Tensor: "
<< "shape[ ";
for (auto s : this->shape()) {
ss << s << " ";
}
ss << "] strides[ ";
for (auto s : this->strides()) {
ss << s << " ";
}
ss << "] dtype=" << this->dtype()
<< this->_desc->info()
<< " device=" << this->deviceType()
<< " device_id=" << this->deviceId();
return ss.str();
return this->_desc->info();
}
void Tensor::debug(const std::string &filename) const {
RUN_INFINI(
infinirtDeviceSynchronize());
RUN_INFINI(infinirtDeviceSynchronize());
std::cout << info() << std::endl;
auto dtype = this->dtype();
void const *cpu_data;
if (this->deviceType() != INFINI_DEVICE_CPU) {
void *cpu_memory = std::malloc(this->_storage->size);
RUN_INFINI(infinirtMemcpy(cpu_memory, this->_storage->memory,
this->_storage->size, INFINIRT_MEMCPY_D2H));
void *cpu_memory = std::malloc(this->_storage->size());
RUN_INFINI(infinirtMemcpy(cpu_memory, this->_storage->memory(),
this->_storage->size(), INFINIRT_MEMCPY_D2H));
cpu_data = cpu_memory;
} else {
cpu_data = this->_storage->memory;
cpu_data = this->_storage->memory();
}
if (!filename.empty()) {
......@@ -245,13 +259,13 @@ void Tensor::debug(const std::string &filename) const {
std::cerr << "Error opening file for writing: " << filename << "\n";
return;
}
outFile.write(reinterpret_cast<const char *>(cpu_data), this->_storage->size);
outFile.write(reinterpret_cast<const char *>(cpu_data), this->_storage->size());
outFile.close();
std::cout << "Data written to file: " << filename << "\n";
return;
}
switch (dtype) {
switch (this->dtype()) {
case INFINI_DTYPE_F16:
print_data((uint16_t const *)((char const *)cpu_data + dataOffset()),
this->shape(), this->strides(), 0);
......
......@@ -7,25 +7,19 @@
// Build a new Tensor view over the same Storage: each SliceParams shrinks
// one dimension to [start, start+len) by adjusting shape and accumulating a
// byte offset. Strides are unchanged (slicing never re-packs memory).
// NOTE(review): diff span held interleaved old/new lines; resolved to the
// refactored version (desc/offset instead of duplicated members).
std::shared_ptr<Tensor> Tensor::sliceImpl(const std::vector<SliceParams> &slices) const {
    std::shared_ptr<Tensor> tensor = std::make_shared<Tensor>();
    auto new_shape = std::vector<size_t>(this->shape());
    ptrdiff_t offset = 0;
    for (const auto &slice : slices) {
        ASSERT(slice.len > 0);
        // Slice must stay inside the current extent of its dimension.
        ASSERT(this->shape()[slice.dim] >= slice.start + slice.len);
        new_shape[slice.dim] = slice.len;
        offset += slice.start * this->strides()[slice.dim];
    }
    tensor->_desc = TensorDesc::create(this->dtype(), new_shape, this->strides());
    // Element offset -> byte offset, stacked on top of this view's offset.
    tensor->_offset = offset * dsize(this->dtype()) + this->_offset;
    tensor->_storage = this->_storage;
    return tensor;
}
......@@ -45,10 +39,10 @@ std::shared_ptr<Tensor const> Tensor::slice(const std::vector<SliceParams> &slic
return this->sliceImpl(slices);
}
std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
void TensorDesc::dimMerge(size_t dim_start, size_t dim_end) {
ASSERT(dim_start <= dim_end && dim_end < this->_shape.size());
if (dim_start == dim_end) {
return shared_from_this();
return;
}
auto new_shape = std::vector<size_t>();
......@@ -68,14 +62,15 @@ std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
}
this->_shape = new_shape;
this->_strides = new_strides;
infiniopDestroyTensorDescriptor(this->_desc);
infiniopCreateTensorDescriptor(&this->_desc, this->_shape.size(), this->_shape.data(),
this->_strides.data(), this->_dtype);
this->resetDesc();
}
// Merge dimensions [dim_start, dim_end] of this tensor's descriptor in
// place (mutates the shared TensorDesc) and return this tensor for chaining.
std::shared_ptr<Tensor> Tensor::dimMerge(size_t dim_start, size_t dim_end) {
    this->_desc->dimMerge(dim_start, dim_end);
    return shared_from_this();
}
std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) {
void TensorDesc::dimSplit(size_t dim, const std::vector<size_t> &dims) {
ASSERT_EQ(this->_shape[dim], std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies<size_t>()));
auto new_shape = std::vector<size_t>();
auto new_strides = std::vector<ptrdiff_t>();
......@@ -93,13 +88,15 @@ std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &
}
this->_shape = new_shape;
this->_strides = new_strides;
infiniopDestroyTensorDescriptor(this->_desc);
infiniopCreateTensorDescriptor(&this->_desc, this->_shape.size(), this->_shape.data(),
this->_strides.data(), this->_dtype);
this->resetDesc();
}
// Split dimension `dim` into the factor sizes `dims` in place (mutates the
// shared TensorDesc) and return this tensor for chaining.
std::shared_ptr<Tensor> Tensor::dimSplit(size_t dim, const std::vector<size_t> &dims) {
    this->_desc->dimSplit(dim, dims);
    return shared_from_this();
}
std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
void TensorDesc::permute(const std::vector<size_t> &order) {
ASSERT_EQ(this->_shape.size(), order.size());
auto new_shape = std::vector<size_t>(order.size());
auto new_strides = std::vector<ptrdiff_t>(order.size());
......@@ -110,8 +107,10 @@ std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
}
this->_shape = new_shape;
this->_strides = new_strides;
infiniopDestroyTensorDescriptor(this->_desc);
infiniopCreateTensorDescriptor(&this->_desc, this->_shape.size(), this->_shape.data(),
this->_strides.data(), this->_dtype);
this->resetDesc();
}
// Reorder dimensions according to `order` in place (mutates the shared
// TensorDesc; no data movement) and return this tensor for chaining.
std::shared_ptr<Tensor> Tensor::permute(const std::vector<size_t> &order) {
    this->_desc->permute(order);
    return shared_from_this();
}
\ No newline at end of file
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment