Commit 3443835c authored by Chao Liu
Browse files

bug fix: add missing data-type in inner_product_with_conversion

parent 59968d8d
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
#endif #endif
#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM #ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0 #define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif #endif
#ifndef CK_USE_AMD_V_FMAC_F32 #ifndef CK_USE_AMD_V_FMAC_F32
......
...@@ -322,6 +322,8 @@ struct inner_product_with_conversion ...@@ -322,6 +322,8 @@ struct inner_product_with_conversion
return acc; return acc;
} }
__device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); }
// hack for int8x4_t, because compiler does not have native support for int8x4_t // hack for int8x4_t, because compiler does not have native support for int8x4_t
// int8x4_t is defined as int32_t // int8x4_t is defined as int32_t
__device__ T operator()(int8x4_t a, int8x4_t b) const __device__ T operator()(int8x4_t a, int8x4_t b) const
......
...@@ -49,7 +49,7 @@ int main(int argc, char* argv[]) ...@@ -49,7 +49,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 4;
constexpr index_t HI = 270; constexpr index_t HI = 270;
...@@ -730,10 +730,12 @@ int main(int argc, char* argv[]) ...@@ -730,10 +730,12 @@ int main(int argc, char* argv[])
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 1
#if 0 #if 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<float, float, float>( device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<float, float, float>(
#elif 1 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t, int32_t, int32_t>( device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t,
int32_t,
int32_t>(
#elif 1 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t, int32_t, int8_t>( device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t, int32_t, int8_t>(
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment