Commit e6d9dd20 authored by ltqin's avatar ltqin
Browse files

fix by comments

parent 5cfd01fd
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
namespace ck { namespace ck {
// GemmM = K // GemmM = K
// GemmN = N * Ho * Wo // GemmK = N * Ho * Wo
// GemmK = C * Y * X // GemmN = C * Y * X
template <typename... Wei, template <typename... Wei,
typename... In, typename... In,
typename... Out, typename... Out,
......
...@@ -48,7 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk ...@@ -48,7 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1 #if 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
......
...@@ -87,13 +87,13 @@ int main(int argc, char* argv[]) ...@@ -87,13 +87,13 @@ int main(int argc, char* argv[])
const bool do_log = std::stoi(argv[5]); const bool do_log = std::stoi(argv[5]);
const int nrepeat = std::stoi(argv[6]); const int nrepeat = std::stoi(argv[6]);
constexpr index_t N = 128; constexpr index_t N = Number<128>;
constexpr index_t C = 192; constexpr index_t C = Number<128>;
constexpr index_t Hi = 71; constexpr index_t Hi = Number<14>;
constexpr index_t Wi = 71; constexpr index_t Wi = Number<14>;
constexpr index_t K = 256; constexpr index_t K = Number<256>;
constexpr index_t Y = 3; constexpr index_t Y = Number<3>;
constexpr index_t X = 3; constexpr index_t X = Number<3>;
const index_t conv_stride_h = 2; const index_t conv_stride_h = 2;
const index_t conv_stride_w = 2; const index_t conv_stride_w = 2;
...@@ -200,8 +200,8 @@ int main(int argc, char* argv[]) ...@@ -200,8 +200,8 @@ int main(int argc, char* argv[])
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break; break;
case 5: case 5:
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 0.01}, num_thread); in.GenerateTensorValue(GeneratorTensor_3<float>{-0.01, 0.01}, num_thread);
out.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 0.01}, num_thread); out.GenerateTensorValue(GeneratorTensor_3<float>{-0.01, 0.01}, num_thread);
break; break;
default: default:
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment