Commit e6d9dd20 authored by ltqin's avatar ltqin
Browse files

fix by comments

parent 5cfd01fd
......@@ -8,8 +8,8 @@
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
// GemmK = N * Ho * Wo
// GemmN = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
......
......@@ -48,7 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
......
......@@ -87,13 +87,13 @@ int main(int argc, char* argv[])
const bool do_log = std::stoi(argv[5]);
const int nrepeat = std::stoi(argv[6]);
constexpr index_t N = 128;
constexpr index_t C = 192;
constexpr index_t Hi = 71;
constexpr index_t Wi = 71;
constexpr index_t K = 256;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t N = Number<128>;
constexpr index_t C = Number<128>;
constexpr index_t Hi = Number<14>;
constexpr index_t Wi = Number<14>;
constexpr index_t K = Number<256>;
constexpr index_t Y = Number<3>;
constexpr index_t X = Number<3>;
const index_t conv_stride_h = 2;
const index_t conv_stride_w = 2;
......@@ -200,8 +200,8 @@ int main(int argc, char* argv[])
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
case 5:
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 0.01}, num_thread);
out.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 0.01}, num_thread);
in.GenerateTensorValue(GeneratorTensor_3<float>{-0.01, 0.01}, num_thread);
out.GenerateTensorValue(GeneratorTensor_3<float>{-0.01, 0.01}, num_thread);
break;
default:
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment