Commit 31fcf744 authored by YdrMaster's avatar YdrMaster
Browse files

issue/50/refactor: workspace_size 改为 minWorkSpaceSize()


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent b5ccf30f
...@@ -41,6 +41,10 @@ infiniStatus_t Descriptor::create( ...@@ -41,6 +41,10 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
size_t Descriptor::minWorkspaceSize() const {
return _min_workspace_size;
}
template <typename DT> template <typename DT>
struct ComputeType { struct ComputeType {
using type = DT; using type = DT;
......
...@@ -37,9 +37,10 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize( ...@@ -37,9 +37,10 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
infiniopRandomSampleDescriptor_t desc, infiniopRandomSampleDescriptor_t desc,
size_t *size) { size_t *size) {
#define GET(CASE, NAMESPACE) \ #define GET(CASE, NAMESPACE) \
case CASE: \ case CASE: \
*size = reinterpret_cast<const op::random_sample::NAMESPACE::Descriptor *>(desc)->workspace_size; \ using Ptr = const op::random_sample::NAMESPACE::Descriptor *; \
*size = reinterpret_cast<Ptr>(desc)->minWorkspaceSize(); \
return INFINI_STATUS_SUCCESS return INFINI_STATUS_SUCCESS
switch (desc->device_type) { switch (desc->device_type) {
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
Opaque *_opaque; \ Opaque *_opaque; \
\ \
infiniDtype_t _dt_i, _dt_p; \ infiniDtype_t _dt_i, _dt_p; \
size_t _n; \ size_t _n, _min_workspace_size; \
\ \
Descriptor( \ Descriptor( \
infiniDtype_t dt_i, \ infiniDtype_t dt_i, \
infiniDtype_t dt_p, \ infiniDtype_t dt_p, \
size_t n, \ size_t n, \
size_t workspace_size_, \ size_t min_workspace_size, \
Opaque *opaque, \ Opaque *opaque, \
infiniDevice_t device_type, \ infiniDevice_t device_type, \
int device_id) \ int device_id) \
...@@ -27,11 +27,9 @@ ...@@ -27,11 +27,9 @@
_dt_i(dt_i), \ _dt_i(dt_i), \
_dt_p(dt_p), \ _dt_p(dt_p), \
_n(n), \ _n(n), \
workspace_size(workspace_size_) {} \ _min_workspace_size(min_workspace_size) {} \
\ \
public: \ public: \
size_t workspace_size; \
\
~Descriptor(); \ ~Descriptor(); \
\ \
static infiniStatus_t create( \ static infiniStatus_t create( \
...@@ -40,6 +38,8 @@ ...@@ -40,6 +38,8 @@
infiniopTensorDescriptor_t result_desc, \ infiniopTensorDescriptor_t result_desc, \
infiniopTensorDescriptor_t probs_desc); \ infiniopTensorDescriptor_t probs_desc); \
\ \
size_t minWorkspaceSize() const; \
\
infiniStatus_t calculate( \ infiniStatus_t calculate( \
void *workspace, \ void *workspace, \
size_t workspace_size, \ size_t workspace_size, \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment