Commit bae8112d authored by Jing Zhang
Browse files

enable fwd conv on navi4x

parent 255fbc56
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950) list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102 gfx1200)
set(target 0) set(target 0)
foreach(gpu IN LISTS GPU_TARGETS) foreach(gpu IN LISTS GPU_TARGETS)
......
...@@ -90,10 +90,10 @@ struct ExecutionConfig final ...@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool time_kernel = true; bool time_kernel = true;
}; };
#define DefaultConvParam \ #define DefaultConvParam \
ck::utils::conv::ConvParam \ ck::utils::conv::ConvParam \
{ \ { \
2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \ 2, 32, 2, 32, 32, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
} }
inline void print_help_msg() inline void print_help_msg()
......
...@@ -90,10 +90,10 @@ struct ExecutionConfig final ...@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool time_kernel = true; bool time_kernel = true;
}; };
#define DefaultConvParam \ #define DefaultConvParam \
ck::utils::conv::ConvParam \ ck::utils::conv::ConvParam \
{ \ { \
2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \ 2, 32, 2, 32, 32, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
} }
inline void print_help_msg() inline void print_help_msg()
......
...@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
// check device // check device
if(ck::is_navi3_supported()) if(ck::is_navi3_supported() || ck::is_navi4_supported())
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -340,7 +340,8 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -340,7 +340,8 @@ struct GridwiseGemmMultipleD_Wmma
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16; // static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto WmmaK = 16;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
......
...@@ -156,7 +156,7 @@ check_err(const Range& out, ...@@ -156,7 +156,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
// if(err_count < 5) if(err_count < 5)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment