Unverified Commit 85fc91c3 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Minor fix for recent PR (#260)

* fix example

* update IsSupportedArgument

* fix

* disable fp64 conv example as test
parent d32a67a9
...@@ -4,4 +4,5 @@ add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) ...@@ -4,4 +4,5 @@ add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
...@@ -170,9 +170,7 @@ int main(int argc, char* argv[]) ...@@ -170,9 +170,7 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
std::cout << "wrong! device_gemm with the specified compilation parameters does " std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"not support this GEMM problem"
<< std::endl;
return 0; return 0;
} }
......
...@@ -169,9 +169,7 @@ int main(int argc, char* argv[]) ...@@ -169,9 +169,7 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
std::cout << "wrong! device_gemm with the specified compilation parameters does " std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"not support this GEMM problem"
<< std::endl;
return 0; return 0;
} }
......
...@@ -167,9 +167,7 @@ int main(int argc, char* argv[]) ...@@ -167,9 +167,7 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
std::cout << "wrong! device_gemm with the specified compilation parameters does " std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"not support this GEMM problem"
<< std::endl;
return 0; return 0;
} }
......
...@@ -193,9 +193,9 @@ int main(int argc, char* argv[]) ...@@ -193,9 +193,9 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
...@@ -166,9 +166,9 @@ int main(int argc, char* argv[]) ...@@ -166,9 +166,9 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
...@@ -21,8 +21,6 @@ template <ck::index_t... Is> ...@@ -21,8 +21,6 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F64 = double; using F64 = double;
using F32 = float;
using F16 = ck::half_t;
using ADataType = double; using ADataType = double;
using BDataType = double; using BDataType = double;
...@@ -195,9 +193,9 @@ int main(int argc, char* argv[]) ...@@ -195,9 +193,9 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...@@ -233,7 +231,7 @@ int main(int argc, char* argv[]) ...@@ -233,7 +231,7 @@ int main(int argc, char* argv[])
show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl; show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl;
} }
#endif #endif
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
} }
return 0; return 0;
......
...@@ -194,9 +194,9 @@ int main(int argc, char* argv[]) ...@@ -194,9 +194,9 @@ int main(int argc, char* argv[])
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); return 0;
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
target_link_libraries(example_convnd_fwd_xdl_fp64 PRIVATE conv_util) target_link_libraries(example_convnd_fwd_xdl_fp64 PRIVATE conv_util)
target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util) target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util)
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util) target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util)
......
#ifndef DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP #pragma once
#define DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP
#include <functional> #include <functional>
#include <iostream> #include <iostream>
...@@ -8,6 +7,7 @@ ...@@ -8,6 +7,7 @@
#include <sstream> #include <sstream>
#include "device.hpp" #include "device.hpp"
#include "device_prop.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_conv_fwd.hpp" #include "device_conv_fwd.hpp"
#include "convolution_forward_specialization.hpp" #include "convolution_forward_specialization.hpp"
...@@ -858,6 +858,27 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -858,6 +858,27 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx908")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t>))
{
return false;
}
}
else if(ck::get_device_name() == "gfx90a")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
{
return false;
}
}
else
{
return false;
}
// Input tensors can't be bigger than 2GB each. // Input tensors can't be bigger than 2GB each.
constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31); constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31);
...@@ -1021,4 +1042,3 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -1021,4 +1042,3 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#endif
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <sstream> #include <sstream>
#include "device.hpp" #include "device.hpp"
#include "device_prop.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_gemm.hpp" #include "device_gemm.hpp"
#include "common_header.hpp" #include "common_header.hpp"
...@@ -13,7 +14,6 @@ ...@@ -13,7 +14,6 @@
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "gridwise_gemm_dl_v1r3.hpp" #include "gridwise_gemm_dl_v1r3.hpp"
#include "device_prop.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "device.hpp" #include "device.hpp"
#include "device_prop.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_gemm.hpp" #include "device_gemm.hpp"
#include "common_header.hpp" #include "common_header.hpp"
...@@ -11,7 +12,6 @@ ...@@ -11,7 +12,6 @@
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
#include "device_prop.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -408,7 +408,23 @@ struct DeviceGemmXdl ...@@ -408,7 +408,23 @@ struct DeviceGemmXdl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(ck::get_device_name() == "gfx908")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t>))
{
return false;
}
}
else if(ck::get_device_name() == "gfx90a")
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
{
return false;
}
}
else
{ {
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment