"...composable_kernel_rocm.git" did not exist on "42fc8eddd21f5725881f8f503f2cb5724c935cb5"
Commit 3317bfe2 authored by Jing Zhang's avatar Jing Zhang
Browse files

format

parent 2b8e3ece
...@@ -139,7 +139,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw ...@@ -139,7 +139,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}), make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
constexpr auto out_gemmm_gemmn_global_desc = constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc, transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}), make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
......
...@@ -219,12 +219,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1 ...@@ -219,12 +219,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
// 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4), // 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
// we recast datatype from a single half to 4 packed half/2 packed bfloat16 // we recast datatype from a single half to 4 packed half/2 packed bfloat16
// respectively. // respectively.
auto p_a_block_vec = auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_now);
reinterpret_cast<const half4_t*>( auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_now);
p_a_block_now);
auto p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_now);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
...@@ -252,12 +248,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1 ...@@ -252,12 +248,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
auto p_a_block_vec = auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
reinterpret_cast<const half4_t*>( auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
p_a_block_double);
auto p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
...@@ -269,12 +261,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1 ...@@ -269,12 +261,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
p_a_block_vec = p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double + a_block_space);
reinterpret_cast<const half4_t*>( p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double + b_block_space);
p_a_block_double + a_block_space);
p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_double + b_block_space);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
} }
else // if has 1 iteration left else // if has 1 iteration left
...@@ -282,12 +270,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1 ...@@ -282,12 +270,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
auto p_a_block_vec = auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
reinterpret_cast<const half4_t*>( auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
p_a_block_double);
auto p_b_block_vec =
reinterpret_cast<const half4_t*>(
p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread); blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
} }
} }
...@@ -348,7 +332,6 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1 ...@@ -348,7 +332,6 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
} }
} }
}; };
} }
#endif #endif
...@@ -87,7 +87,8 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc, ...@@ -87,7 +87,8 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw< constexpr auto gridwise_conv =
GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw<
GridSize, GridSize,
BlockSize, BlockSize,
half, half,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment