Commit cd29b09a authored by Chao Liu

refactor

parent a6b95c39
@@ -608,11 +608,11 @@ int main(int argc, char* argv[])
     device_convolution_direct_v2_nchw_kcyx_nkhw
 #elif 0
     device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
-#elif 1
+#elif 0
     device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
 #elif 0
     device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
-#elif 0
+#elif 1
     device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
 #elif 0
     device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
@@ -100,8 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
         const index_t wi_block_data_begin = wo_block_data_begin;
 
         // global tensor view
-        constexpr auto wei_c_k_global_desc =
-            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
+        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
 
         // LDS tensor view
         // be careful of alignment
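The hunk above drops the hand-written (C, K) weight descriptor, whose strides Y * X * K and 1 had to be kept in sync with the 4-d weight layout by hand, and instead derives it from wei_c_y_x_k_global_desc. A minimal sketch of what Extract is expected to do here, keeping the lengths and strides of the selected dimensions (DescModel is a hypothetical stand-in for ConstantTensorDescriptor, and plain indices stand in for the I0/I3 Number<> tags):

```cpp
#include <array>
#include <cstddef>

// hypothetical model of ConstantTensorDescriptor: just lengths and strides
template <std::size_t NDim>
struct DescModel
{
    std::array<std::size_t, NDim> lengths;
    std::array<std::size_t, NDim> strides;

    // Extract: keep the selected dimensions' lengths and strides
    template <typename... Is>
    constexpr DescModel<sizeof...(Is)> Extract(Is... is) const
    {
        return {{lengths[is]...}, {strides[is]...}};
    }
};

int main()
{
    constexpr std::size_t C = 128, Y = 3, X = 3, K = 128;

    // packed C x Y x X x K weight view, row-major strides
    constexpr DescModel<4> wei_c_y_x_k{{C, Y, X, K}, {Y * X * K, X * K, K, 1}};

    // Extract(0, 3) plays the role of Extract(I0, I3) in the commit
    constexpr auto wei_c_k = wei_c_y_x_k.Extract(0, 3);

    // identical to the removed hand-written descriptor:
    // lengths (C, K), strides (Y * X * K, 1)
    static_assert(wei_c_k.lengths[0] == C && wei_c_k.lengths[1] == K, "");
    static_assert(wei_c_k.strides[0] == Y * X * K && wei_c_k.strides[1] == 1, "");
}
```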
@@ -359,13 +358,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
         const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
         const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;
 
-        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
-                                                                         // perfect forwarding.
-                                                                         // Using this trick to
-            // make this lambda a generic lambda, so it won't be compiled until
-            // instantiated
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
+            // fwd does nothing but perfect forwarding.
+            // Using this trick makes this lambda a generic lambda, so it won't be compiled
+            // until being instantiated here
             static_assert(
-                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                 "wrong!");
 
             // output is a 10d tensor
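The renamed fwd parameter (formerly f_dummy) is the heart of the static_if trick: taking an auto parameter makes each branch a generic lambda, i.e. a template whose body is type-checked only when the chosen static_if specialization actually calls it, and writing fwd(GemmNPerThreadSubC) keeps the static_assert dependent so the untaken branch never fires. A self-contained sketch of the pattern, assuming a minimal static_if along these lines (illustrative, not the repo's exact implementation):

```cpp
#include <utility>

// identity functor: forwards its argument unchanged, but with a dependent type
struct identity
{
    template <typename T>
    constexpr T&& operator()(T&& t) const
    {
        return std::forward<T>(t);
    }
};

template <bool Cond> // primary: Cond == true, run the then-body
struct static_if
{
    template <typename F>
    static_if operator()(F f) const
    {
        f(identity{}); // taken branch: instantiate and run
        return *this;
    }

    template <typename F>
    void else_(F) const {} // skip the else-body entirely
};

template <> // Cond == false: run only the else-body
struct static_if<false>
{
    template <typename F>
    static_if operator()(F) const { return *this; } // skip the then-body

    template <typename F>
    void else_(F f) const { f(identity{}); }
};

int main()
{
    constexpr int NPerBlock = 16, GemmNPerThreadSubC = 4;

    static_if<(GemmNPerThreadSubC <= NPerBlock)>{}([&](auto fwd) {
        // compiled only because this branch is taken
        static_assert(fwd(GemmNPerThreadSubC) <= NPerBlock, "wrong!");
    }).else_([&](auto fwd) {
        // false for these sizes, yet harmless: never instantiated
        static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock, "wrong!");
    });
}
```

The else_ branch's static_assert would fail for these sizes, but because its lambda is generic and never invoked, its body is never instantiated; that is exactly why the untaken branch compiles.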
@@ -373,37 +371,32 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
             constexpr index_t N1 = NPerBlock / N2;
 
             constexpr index_t W2 =
-                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
             constexpr index_t W1 = WoPerBlock / W2;
 
             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;
 
-            constexpr auto out_10d_global_desc =
-                make_ConstantTensorDescriptor(Sequence<N / f_dummy(N1 * N2),
-                                                       N1,
-                                                       N2,
-                                                       K / (K1 * K2),
-                                                       K1,
-                                                       K2,
-                                                       Ho,
-                                                       Wo / (W1 * W2),
-                                                       W1,
-                                                       W2>{});
-
-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+            constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc)
+                                                     .Fold(I3, Number<W1>{}, Number<W2>{})
+                                                     .Fold(I1, Number<K1>{}, Number<K2>{})
+                                                     .Fold(I0, Number<N1>{}, Number<N2>{});
+
+            constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
+                                                     .Fold(I3, Number<1>{}, Number<N2>{})
+                                                     .Fold(I2, Number<W1>{}, Number<1>{})
+                                                     .Fold(I0, Number<1>{}, Number<K2>{});
 
 #if 0
             if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
             {
                 print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
-                                               "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                                               "a: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
 
-                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
-                                               "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+                print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
+                                               "a: out_n_k_h_w_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
             }
 #endif
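The Fold refactor above derives the 10-d output descriptor from the 4-d N x K x Ho x Wo one instead of spelling out all ten lengths in a Sequence. Assuming Fold(d, l1, ..., lk) splits a length-L dimension d into the k+1 lengths (L / (l1 * ... * lk), l1, ..., lk), applied from the highest dimension index down so the lower indices stay valid, the lengths come out identical to the removed Sequence. A runnable sketch of that bookkeeping (FoldDim is a hypothetical model of the length arithmetic only, and the sizes are made up but divisible):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Lengths = std::vector<long>;

// fold dimension `dim` of length L with factors (l1, ..., lk):
// replace it with the k+1 lengths (L / (l1 * ... * lk), l1, ..., lk)
Lengths FoldDim(Lengths in, std::size_t dim, const Lengths& factors)
{
    long prod = 1;
    for(long f : factors)
        prod *= f;

    Lengths out(in.begin(), in.begin() + dim);
    out.push_back(in[dim] / prod); // leading "leftover" length
    out.insert(out.end(), factors.begin(), factors.end());
    out.insert(out.end(), in.begin() + dim + 1, in.end());
    return out;
}

int main()
{
    // illustrative sizes chosen to satisfy the divisibility requirements
    const long N = 32, K = 128, Ho = 28, Wo = 28;
    const long N1 = 2, N2 = 4, K1 = 8, K2 = 4, W1 = 7, W2 = 2;

    Lengths out_n_k_h_w = {N, K, Ho, Wo};

    // mirrors: .Fold(I3, W1, W2).Fold(I1, K1, K2).Fold(I0, N1, N2)
    Lengths out_10d = FoldDim(out_n_k_h_w, 3, {W1, W2});
    out_10d         = FoldDim(out_10d, 1, {K1, K2});
    out_10d         = FoldDim(out_10d, 0, {N1, N2});

    // same lengths the removed Sequence<...> spelled out explicitly
    Lengths expected = {
        N / (N1 * N2), N1, N2, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2};

    assert(out_10d == expected);
}
```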
@@ -421,8 +414,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
                 out_10d_thread_desc.GetLengths(),
                 map_out_global2thread);
             // Number<OutThreadCopyDataPerWrite_W>{});
 
-        }).else_([&](auto f_dummy) {
-            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+        }).else_([&](auto fwd) {
+            static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
@@ -431,28 +424,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
             constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
             constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
-            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+            constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
 
             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;
 
-            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
-                Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3>{});
+            constexpr auto out_10d_global_desc =
+                fwd(out_n_k_h_w_global_desc)
+                    .Fold(I3, Number<W1>{}, Number<W2>{}, Number<W3>{})
+                    .Fold(I1, Number<K1>{}, Number<K2>{})
+                    .Fold(I0, Number<N1>{});
 
-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+            constexpr auto out_10d_thread_desc =
+                fwd(out_k_h_w_n_thread_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
+                    .Fold(I0, Number<1>{}, Number<K2>{});
 
 #if 0
             if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
             {
                 print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
-                                               "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
-                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
-                                               "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+                                               "b: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
+
+                print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
+                                               "b: out_n_k_h_w_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
             }
 #endif
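The else_ branch applies the same idea with a different factorization: Wo is folded with three factors (W1, W2, W3) and N with one (N1). Since folding a dimension with k factors yields k + 1 dimensions, the result is again 10-d: 2 (from N) + 3 (from K) + 1 (Ho) + 4 (from Wo), matching the removed Sequence<N / N1, N1, K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3>.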