Commit 94e3a2e4 authored by Shucai Xiao

change size_t to int

parent 26bd92d8
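
The diff below applies one mechanical pattern across the GPU backend: loop counters, element counts, and size/index parameters typed as std::size_t (or size_t) become int. A minimal before/after sketch of the pattern; the function and variable names here are illustrative, not taken from the diff:

#include <vector>

void iterate(const std::vector<float>& v)
{
    // Before: for(std::size_t i = 0; i < v.size(); ++i)
    // After, as applied throughout this commit. Note that v.size() still
    // returns std::size_t, so the loop condition now compares signed
    // against unsigned, and sizes above INT_MAX would be narrowed.
    for(int i = 0; i < v.size(); ++i)
        (void)v[i]; // loop body elided
}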

@@ -36,7 +36,7 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg
             return;
         auto index = si.multi(j);
-        for(size_t k = 0; k < index.size(); ++k)
+        for(int k = 0; k < index.size(); ++k)
         {
             ptr[k * elem_num + out_loc] = index[k];
         }

@@ -15,7 +15,7 @@ namespace device {
 argument
 pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads)
 {
-    std::size_t nelements = arg1.get_shape().elements();
+    int nelements = arg1.get_shape().elements();
     hip_visit_all(result, arg1)([&](auto output, auto input) {
         using type      = typename decltype(output)::value_type;
         using hip_index = typename decltype(output)::hip_index;
@@ -27,7 +27,7 @@ pad(hipStream_t stream, argument result, argument arg1, float value, std::vector
         std::copy(pads.begin(), pads.begin() + offsets.size(), offsets.begin());
         gs_launch(stream, nelements)([=](auto i) __device__ {
             auto idx = input.get_shape().multi(i);
-            for(std::size_t j = 0; j < offsets.size(); j++)
+            for(int j = 0; j < offsets.size(); j++)
             {
                 idx[j] += offsets[j];
             }
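
For context, the pad kernel shifts each input coordinate by the leading pad amount of its dimension to find its position in the padded output. A standalone sketch of that index arithmetic with made-up values:

#include <iostream>
#include <vector>

int main()
{
    std::vector<int> idx     = {0, 1}; // coordinate in the input (made up)
    std::vector<int> offsets = {2, 3}; // leading pads per dimension (made up)
    // Same loop shape as the kernel above: shift every dimension.
    for(int j = 0; j < offsets.size(); j++)
        idx[j] += offsets[j];
    std::cout << idx[0] << "," << idx[1] << "\n"; // prints 2,4
}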

@@ -9,7 +9,7 @@ namespace device {
 void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
 {
     index_int item_num = arg.get_shape().elements() / result.get_shape().elements();
-    reduce(stream, result, arg, sum{}, 0, id{}, mean{item_num});
+    reduce(stream, result, arg, sum{}, 0, id{}, mean{static_cast<int>(item_num)});
 }
 } // namespace device
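
item_num here is the reduction size: the ratio of input elements to output elements, which mean{} uses as the divisor. A tiny illustration with made-up shapes:

#include <iostream>

int main()
{
    int in_elements  = 2 * 3 * 4; // a {2, 3, 4} input, made up
    int out_elements = 2 * 3 * 1; // reduced over its last axis
    int item_num     = in_elements / out_elements;
    std::cout << item_num << "\n"; // prints 4; mean{4} divides each sum by 4
}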

@@ -16,9 +16,9 @@ reverse(hipStream_t stream, argument result, argument arg1, const std::vector<in
 {
     auto s = arg1.get_shape();
     // auto lens = s.lens();
-    std::vector<std::size_t> axis_len(axes.begin(), axes.end());
+    std::vector<int> axis_len(axes.begin(), axes.end());
     shape sa{shape::float_type, axis_len};
-    std::size_t nelements = s.elements();
+    int nelements = s.elements();
     visit_all(result, arg1)([&](auto output1, auto input1) {
         hip_visit_views(output1, input1, s)([&](auto output, auto input, auto hs) {
             hip_visit_views(sa)([&](auto daxes) {

@@ -142,7 +142,7 @@ std::vector<argument> topk(hipStream_t stream,
     auto comp_lens  = in_lens;
     comp_lens[axis] = 1;
     shape comp_s{in_s.type(), comp_lens};
-    std::size_t elem_num = comp_s.elements();
+    int elem_num = comp_s.elements();
     hip_visit_all(val_res, arg, out_s, in_s, comp_s)(
         [&](auto out_val, auto input, auto oss, auto iss, auto css) {
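
Collapsing the sorted axis to length 1 makes comp_s.elements() the number of independent top-k slices. A standalone sketch of that arithmetic with made-up lens:

#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    std::vector<int> in_lens = {4, 8, 16}; // made-up input lens
    int axis        = 2;                   // top-k along the last axis
    auto comp_lens  = in_lens;
    comp_lens[axis] = 1;
    // The product of the remaining lens is the slice count.
    int elem_num = std::accumulate(
        comp_lens.begin(), comp_lens.end(), int{1}, std::multiplies<int>());
    std::cout << elem_num << "\n"; // prints 32 (= 4 * 8)
}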

@@ -15,8 +15,8 @@ namespace driver {
 shape parser::parse_shape(const value& v) const
 {
-    auto lens    = get(v, "lens", std::vector<std::size_t>{});
-    auto strides = get(v, "strides", std::vector<std::size_t>{});
+    auto lens    = get(v, "lens", std::vector<int>{});
+    auto strides = get(v, "strides", std::vector<int>{});
     auto type    = shape::parse_type(get<std::string>(v, "type", "float"));
     if(strides.empty())
         return shape{type, lens};

@@ -13,7 +13,7 @@ namespace gpu {
 void eliminate_workspace::apply(module& p) const
 {
-    std::size_t n = 0;
+    int n = 0;
     std::vector<instruction_ref> allocs;
     for(auto ins : iterator_for(p))
     {

@@ -64,7 +64,7 @@ struct fusion
     bool empty() const { return fp == nullptr; }

-    op_t operator[](std::size_t i) const
+    op_t operator[](int i) const
     {
         assert(fp);
         op_t result;
@@ -118,7 +118,7 @@ struct fusion
     {
         // assert(fp);
         // TODO: Use zero workspace for now
-        std::size_t ws_size = 0;
+        int ws_size = 0;
         // int algo_count = 1;
         // miopenConvFwdAlgorithm_t algo;
         // miopenFusionPlanConvolutionGetAlgo(fp.get(), 1, &algo_count, &algo);
@@ -596,7 +596,7 @@ struct miopen_fusion
     {
         // Compensate for allocation
         inputs.pop_back();
-        std::size_t i = 0;
+        int i = 0;
         f = fusion(inputs[i]);
         i++;
         std::vector<std::function<void(const fused_operator_args&, const std::vector<argument>&)>>

@@ -90,7 +90,7 @@ void gemm_impl(context& ctx,
     }
     auto num_matrices = std::accumulate(
-        out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
+        out_lens.rbegin() + 2, out_lens.rend(), int{1}, std::multiplies<int>());
     if(num_matrices == 1)
     {
         // the rocblas_gemm API handles inputs and output matrices as
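
The accumulate call above multiplies every dimension except the trailing two, giving the number of matrices in the batch. A standalone version with made-up lens:

#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    // Made-up output lens for a batched GEMM: the last two dims are the
    // matrix, everything before them is batch.
    std::vector<int> out_lens = {2, 3, 64, 64};
    auto num_matrices         = std::accumulate(
        out_lens.rbegin() + 2, out_lens.rend(), int{1}, std::multiplies<int>());
    std::cout << num_matrices << "\n"; // prints 6 (= 2 * 3)
}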

@@ -27,10 +27,10 @@ using hip_host_ptr = MIGRAPHX_MANAGE_PTR(void, hipHostUnregister);
 std::string hip_error(int error) { return hipGetErrorString(static_cast<hipError_t>(error)); }

-std::size_t get_available_gpu_memory()
+int get_available_gpu_memory()
 {
-    size_t free;
-    size_t total;
+    std::size_t free;
+    std::size_t total;
     auto status = hipMemGetInfo(&free, &total);
     if(status != hipSuccess)
         MIGRAPHX_THROW("Failed getting available memory: " + hip_error(status));
@@ -46,7 +46,7 @@ void* get_device_ptr(void* hptr)
     return result;
 }

-hip_ptr allocate_gpu(std::size_t sz, bool host = false)
+hip_ptr allocate_gpu(int sz, bool host = false)
 {
     if(sz > get_available_gpu_memory())
         MIGRAPHX_THROW("Memory not available to allocate buffer: " + std::to_string(sz));
@@ -62,7 +62,7 @@ hip_ptr allocate_gpu(std::size_t sz, bool host = false)
     return hip_ptr{result};
 }

-hip_host_ptr register_on_gpu(void* ptr, std::size_t sz)
+hip_host_ptr register_on_gpu(void* ptr, int sz)
 {
     auto status = hipHostRegister(ptr, sz, hipHostRegisterMapped);
     if(status != hipSuccess)
@@ -71,7 +71,7 @@ hip_host_ptr register_on_gpu(void* ptr, std::size_t sz)
 }

 template <class T>
-std::vector<T> read_from_gpu(const void* x, std::size_t sz)
+std::vector<T> read_from_gpu(const void* x, int sz)
 {
     gpu_sync();
     std::vector<T> result(sz);
@@ -81,7 +81,7 @@ std::vector<T> read_from_gpu(const void* x, std::size_t sz)
     return result;
 }

-hip_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false)
+hip_ptr write_to_gpu(const void* x, int sz, bool host = false)
 {
     gpu_sync();
     auto result = allocate_gpu(sz, host);
@@ -133,7 +133,7 @@ argument from_gpu(const argument& arg)
     return result;
 }

-void set_device(std::size_t id)
+void set_device(int id)
 {
     auto status = hipSetDevice(id);
     if(status != hipSuccess)
@@ -151,8 +151,8 @@ void gpu_sync(const context& ctx) { ctx.finish(); }
 void hip_async_copy(context& ctx, const argument& src, const argument& dst, hipMemcpyKind kind)
 {
-    std::size_t src_size = src.get_shape().bytes();
-    std::size_t dst_size = dst.get_shape().bytes();
+    int src_size = src.get_shape().bytes();
+    int dst_size = dst.get_shape().bytes();
     if(src_size > dst_size)
         MIGRAPHX_THROW("Not enough memory available in destination to do copy");
     auto status = hipMemcpyAsync(dst.data(), src.data(), src_size, kind, ctx.get_stream().get());
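
hipMemGetInfo reports byte counts through std::size_t out-parameters, which is why free and total keep that type even though get_available_gpu_memory now returns int (free memory beyond INT_MAX bytes, about 2 GiB, no longer fits the return type). A minimal standalone query mirroring the function above:

#include <hip/hip_runtime.h>
#include <cstddef>
#include <iostream>

int main()
{
    std::size_t free  = 0;
    std::size_t total = 0;
    if(hipMemGetInfo(&free, &total) != hipSuccess)
        return 1;
    std::cout << "free: " << free << ", total: " << total << " bytes\n";
    return 0;
}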

@@ -17,8 +17,8 @@ struct code_object_op
 {
     value::binary code_object;
     std::string symbol_name;
-    std::size_t global;
-    std::size_t local;
+    int global;
+    int local;
     std::vector<shape> expected_inputs;
     shape output;
     kernel k{};

@@ -15,9 +15,9 @@ namespace gpu {
 std::vector<std::vector<char>>
 compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch);

-std::string enum_params(std::size_t count, std::string param);
-std::size_t compute_global(std::size_t n, std::size_t local = 1024);
+std::string enum_params(int count, std::string param);
+int compute_global(int n, int local = 1024);
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
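
compute_global's definition is not part of this diff; a plausible sketch, assuming it rounds the element count up to a whole number of workgroups of size local:

// Sketch under the assumption stated above, not the actual implementation.
int compute_global(int n, int local = 1024)
{
    int groups = (n + local - 1) / local; // ceiling division
    return groups * local;                // compute_global(1500) == 2048
}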

@@ -10,8 +10,8 @@ namespace gpu {
 struct hip_compile_options
 {
-    std::size_t global;
-    std::size_t local;
+    int global;
+    int local;
     std::vector<shape> inputs;
     shape output;
     std::string kernel_name = "kernel";

@@ -29,13 +29,13 @@ struct hip_device
         add_stream();
     }

-    hip_device(std::size_t id, std::size_t n) : device_id(id)
+    hip_device(int id, int n) : device_id(id)
     {
         auto status = hipGetDeviceProperties(&device_props, device_id);
         if(status != hipSuccess)
             MIGRAPHX_THROW("Failed to allocate stream");
-        for(std::size_t i = 0; i < n; i++)
+        for(int i = 0; i < n; i++)
             add_stream();
     }
@@ -45,7 +45,7 @@ struct hip_device
         stream() {}
-        stream(std::size_t device_number) : id(device_number) {}
+        stream(int device_number) : id(device_number) {}

         void setup() const { set_device(id); }
@@ -124,7 +124,7 @@ struct hip_device
         }

         private:
-        std::size_t id = 0;
+        int id = 0;
         shared<hip_stream_ptr> s            = nullptr;
         shared<miopen_handle> mihandle      = nullptr;
         shared<rocblas_handle_ptr> rbhandle = nullptr;
@@ -134,29 +134,29 @@ struct hip_device
     stream& get_stream() { return streams.at(current_stream); }
-    stream& get_stream(std::size_t n) { return streams.at(n); }
+    stream& get_stream(int n) { return streams.at(n); }

     const stream& get_stream() const { return streams.at(current_stream); }
-    const stream& get_stream(std::size_t n) const { return streams.at(n); }
+    const stream& get_stream(int n) const { return streams.at(n); }

-    void set_stream(std::size_t n) { current_stream = n; }
+    void set_stream(int n) { current_stream = n; }

-    std::size_t nstreams() const { return streams.size(); }
+    int nstreams() const { return streams.size(); }

-    std::size_t stream_id() const { return current_stream; }
+    int stream_id() const { return current_stream; }

     std::string get_device_name() const { return device_props.gcnArchName; }

-    std::size_t get_device_major() const { return device_props.major; }
+    int get_device_major() const { return device_props.major; }

-    std::size_t get_device_minor() const { return device_props.minor; }
+    int get_device_minor() const { return device_props.minor; }

-    std::size_t get_cu_count() const { return device_props.multiProcessorCount; }
+    int get_cu_count() const { return device_props.multiProcessorCount; }

     private:
-    std::size_t device_id      = 0;
-    std::size_t current_stream = 0;
+    int device_id      = 0;
+    int current_stream = 0;
     std::vector<stream> streams;
     hipDeviceProp_t device_props;
@@ -166,7 +166,7 @@ struct hip_device
 struct context
 {
-    context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1))
+    context(int device_id = 0, int n = value_of(MIGRAPHX_NSTREAMS{}, 1))
         : current_device(std::make_shared<hip_device>(device_id, n))
     {
     }
@@ -184,23 +184,23 @@ struct context
     }

     hip_device::stream& get_stream() { return get_current_device().get_stream(); }
-    hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); }
+    hip_device::stream& get_stream(int n) { return get_current_device().get_stream(n); }

     const hip_device::stream& get_stream() const { return get_current_device().get_stream(); }
-    const hip_device::stream& get_stream(std::size_t n) const
+    const hip_device::stream& get_stream(int n) const
     {
         return get_current_device().get_stream(n);
     }

-    void set_stream(std::size_t n) { get_current_device().set_stream(n); }
+    void set_stream(int n) { get_current_device().set_stream(n); }

-    void create_events(std::size_t num_of_events)
+    void create_events(int num_of_events)
     {
-        for(std::size_t i = events.size(); i < num_of_events + 1; ++i)
+        for(int i = events.size(); i < num_of_events + 1; ++i)
             events.emplace_back(create_event());
     }

-    hipEvent_t get_event(std::size_t i) const { return events.at(i).get(); }
+    hipEvent_t get_event(int i) const { return events.at(i).get(); }

     std::vector<argument> literals{};
     void finish() const { get_stream().wait(); }
@@ -226,11 +226,11 @@ struct context
     void from_value(const value& v)
     {
         auto v_events = v.at("events");
-        std::size_t n_events = v_events.without_key().to<std::size_t>();
+        int n_events = v_events.without_key().to<int>();
         this->create_events(n_events - 1);

         auto v_streams = v.at("streams");
-        std::size_t n_streams = v_streams.without_key().to<std::size_t>();
+        int n_streams = v_streams.without_key().to<int>();
         this->current_device = std::make_shared<hip_device>(0, n_streams);
     }
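
The event-count arithmetic above is easy to misread: create_events(n) grows the pool to n + 1 events, so from_value's create_events(n_events - 1) leaves exactly n_events events. A standalone sketch with events stubbed as ints:

#include <vector>

struct event_pool // illustrative stand-in for the context above
{
    std::vector<int> events;
    void create_events(int num_of_events)
    {
        for(int i = events.size(); i < num_of_events + 1; ++i)
            events.emplace_back(i); // stand-in for create_event()
    }
};

int main()
{
    event_pool pool;
    int n_events = 4; // as if read from the "events" value
    pool.create_events(n_events - 1);
    return pool.events.size() == 4 ? 0 : 1; // pool holds n_events events
}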

@@ -73,7 +73,7 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
 {
     auto arg_shape  = arg.get_shape();
     auto batch_lens = arg_shape.lens();
-    size_t batch_item_num = batch_lens[axis];
+    int batch_item_num = batch_lens[axis];
     batch_lens[axis] = 1;
     migraphx::shape batch_shape{arg_shape.type(), batch_lens};
     migraphx::shape std_arg_shape{arg_shape.type(), arg_shape.lens()};
@@ -82,8 +82,8 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
         auto* output = device_cast(result.get<int64_t>().data());
         using type   = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
         // use one block for items in one batch.
-        const size_t max_block_size = 256;
-        const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
+        const int max_block_size = 256;
+        const int block_size     = compute_block_size(batch_item_num, max_block_size);
         gs_launch(stream,
                   batch_shape.elements() * block_size,
                   block_size)([=](auto i, auto idx) __device__ {
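
compute_block_size is not shown in this diff; a plausible sketch, assuming it picks the smallest power-of-two block that covers the per-batch item count, capped at the given maximum:

// Sketch under the assumption stated above, not the actual implementation.
int compute_block_size(int n, int max_block_size)
{
    int block_size = 64; // wavefront-sized lower bound (assumption)
    while(block_size < n && block_size < max_block_size)
        block_size *= 2;
    return block_size; // e.g. compute_block_size(100, 256) == 128
}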

@@ -13,7 +13,7 @@ namespace device {
 argument concat(hipStream_t stream,
                 const shape& output_shape,
                 std::vector<argument> args,
-                std::vector<std::size_t> offsets);
+                std::vector<int> offsets);
 } // namespace device
 } // namespace gpu

@@ -89,13 +89,13 @@ struct rocblas_gemm
         return args.back();
     }

-    void batch_not_transposed(const std::vector<std::size_t>& strides) const
+    void batch_not_transposed(const std::vector<int>& strides) const
     {
         if(strides.size() <= 2)
             return;
         auto dim_0       = strides.size() - 2;
         auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
-        std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
+        std::vector<int> batch(strides.begin(), strides.begin() + dim_0);
         if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
         {
             MIGRAPHX_THROW("GPU_GEMM: matrix size and batch size {" + to_string_range(strides) +
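
The check above rejects stride sets whose batch strides are all smaller than the matrix itself, a layout the strided-batched GEMM path cannot treat as plain batches. A standalone sketch that returns a bool where the original throws:

#include <algorithm>
#include <iostream>
#include <vector>

bool batch_not_transposed(const std::vector<int>& strides) // sketch
{
    if(strides.size() <= 2)
        return true; // no batch dimensions to check
    auto dim_0       = strides.size() - 2;
    auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
    std::vector<int> batch(strides.begin(), strides.begin() + dim_0);
    // The original throws when every batch stride is below the matrix size.
    return !std::all_of(
        batch.begin(), batch.end(), [&](auto i) { return i < matrix_size; });
}

int main()
{
    std::cout << batch_not_transposed({12, 4, 1}) << "\n"; // 1: packed batch
    std::cout << batch_not_transposed({1, 4, 12}) << "\n"; // 0: transposed
}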

@@ -22,7 +22,7 @@ argument to_gpu(const argument& arg, bool host = false);
 argument from_gpu(const argument& arg);

-void set_device(std::size_t id);
+void set_device(int id);

 void gpu_sync();
 void gpu_sync(const context& ctx);

@@ -25,16 +25,16 @@ struct kernel
     }

     void launch(hipStream_t stream,
-                std::size_t global,
-                std::size_t local,
+                int global,
+                int local,
                 const std::vector<kernel_argument>& args) const;

     void launch(hipStream_t stream,
-                std::size_t global,
-                std::size_t local,
+                int global,
+                int local,
                 std::vector<void*> args) const;

-    auto launch(hipStream_t stream, std::size_t global, std::size_t local) const
+    auto launch(hipStream_t stream, int global, int local) const
     {
         return [=](auto&&... xs) {
             launch(stream, global, local, std::vector<kernel_argument>{xs...});
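
The third launch overload curries: it captures the launch configuration and returns a callable that packs its arguments into a vector. A standalone illustration of the same pattern, with the kernel launch replaced by a print:

#include <iostream>
#include <vector>

auto make_launcher(int global, int local) // illustrative stand-in
{
    return [=](auto&&... xs) {
        std::vector<int> args{xs...}; // mirrors std::vector<kernel_argument>{xs...}
        std::cout << "launch global=" << global << " local=" << local
                  << " nargs=" << args.size() << "\n";
    };
}

int main()
{
    auto run = make_launcher(4096, 256);
    run(1, 2, 3); // prints: launch global=4096 local=256 nargs=3
}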

@@ -16,7 +16,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

-template <class Derived, std::size_t N>
+template <class Derived, int N>
 struct device_base : oper<Derived>
 {
     template <class Self, class F>
@@ -32,7 +32,7 @@ struct device_base : oper<Derived>
         reduce_shapes = reduce_dims(inputs);
     }

-    argument get_arg(const std::vector<argument>& args, std::size_t i) const
+    argument get_arg(const std::vector<argument>& args, int i) const
     {
         if(reduce_shapes.empty())
             return args[i];