Commit 3317bfe2 authored by Jing Zhang
Browse files

format

parent 2b8e3ece
......@@ -139,7 +139,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
......
......@@ -219,12 +219,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
// 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
// we recast datatype from a single half to 4 packed half/2 packed bfloat16
// respectively.
auto p_a_block_vec =
reinterpret_cast<const half4_t*>(
p_a_block_now);
auto p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_now);
auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_now);
auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_now);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store next data to LDS
......@@ -252,12 +248,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
auto p_a_block_vec =
reinterpret_cast<const half4_t*>(
p_a_block_double);
auto p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_double);
auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store last data to LDS
......@@ -269,12 +261,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
__syncthreads();
// LDS double buffer: GEMM on current data
p_a_block_vec =
reinterpret_cast<const half4_t*>(
p_a_block_double + a_block_space);
p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_double + b_block_space);
p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double + a_block_space);
p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double + b_block_space);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
}
else // if has 1 iteration left
......@@ -282,12 +270,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
__syncthreads();
// LDS double buffer: GEMM on last data
auto p_a_block_vec =
reinterpret_cast<const half4_t*>(
p_a_block_double);
auto p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_double);
auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
}
}
......@@ -348,7 +332,6 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
}
}
};
}
#endif
......@@ -87,7 +87,8 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw<
constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw<
GridSize,
BlockSize,
half,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment