Unverified Commit 2ed46170 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Merge pull request #127 from ROCmSoftwarePlatform/scalar_tensor

Scalar tensor
parents ca33154a fbef5744
pfultz2/rocm-recipes
pcre
danmar/cppcheck@f965e5873 -DHAVE_RULES=1
ROCm-Developer-Tools/HIP@3a41f286203968421c557338d6fb39c36f3c717c
ROCm-Developer-Tools/HIP@3c7f5dbce24802ec4237e615038daff2909a2e8e
# python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt
......@@ -141,7 +141,6 @@ struct onnx_parser
if(s0->size() > s1->size())
std::swap(s0, s1);
// Copy the larger vector to output_lens
std::vector<std::size_t> output_lens(s1->size());
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
......@@ -588,6 +587,11 @@ struct onnx_parser
static literal parse_tensor(const onnx::TensorProto& t)
{
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.size() == 0)
{
dims = {1};
}
if(t.has_raw_data())
{
const std::string& s = t.raw_data();
......
......@@ -20,7 +20,7 @@ add_library(migraphx_device
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device)
target_link_libraries(migraphx_device migraphx hip::device)
target_link_libraries(migraphx_device migraphx hip::device -Wno-invalid-command-line-argument -amdgpu-target=gfx803 -amdgpu-target=gfx900 -amdgpu-target=gfx903)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
......
......@@ -75,6 +75,15 @@ device_type<T>* device_cast(T* x)
return reinterpret_cast<device_type<T>*>(x);
}
template <class T>
T to_hip_type(T x)
{
return x;
}
// Hip doens't support __fp16
inline float to_hip_type(gpu_half x) { return x; }
} // namespace device
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
......
......@@ -9,7 +9,7 @@ namespace device {
void sin(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) { return ::sin(x); });
nary(stream, result, arg)([](auto x) { return ::sin(to_hip_type(x)); });
}
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment