Commit 573e1b64 authored by ltqin
Browse files

modify kBatch value

parent eed64f7e
...@@ -76,7 +76,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_ ...@@ -76,7 +76,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
constexpr index_t KBatch = 96; constexpr index_t KBatch = 64;
#elif 1 #elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
......
...@@ -253,8 +253,7 @@ int main(int argc, char* argv[]) ...@@ -253,8 +253,7 @@ int main(int argc, char* argv[])
in_left_pads_dev, in_left_pads_dev,
in_right_pads_dev); in_right_pads_dev);
}; };
// set zero to wei_device // set zero to wei_device
wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread); wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread);
#if USE_CONV_WRW_V4R4R2_XDL_NCHW #if USE_CONV_WRW_V4R4R2_XDL_NCHW
...@@ -284,7 +283,6 @@ int main(int argc, char* argv[]) ...@@ -284,7 +283,6 @@ int main(int argc, char* argv[])
} }
#endif #endif
#if USE_CONV_WRW_V4R4R4_XDL_NHWC #if USE_CONV_WRW_V4R4R4_XDL_NHWC
if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC) if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC)
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment