Commit be98313d authored by Jing Zhang's avatar Jing Zhang
Browse files

add b tile permute

parent e053e947
......@@ -170,6 +170,41 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
#if 1
int NPerBlock = 32;
int KPerBlock = 128;
int N1 = NPerBlock;
int K1 = KPerBlock;
int N0 = N / N1;
int K0 = K / K1;
for(int i = 0; i < N0; i++)
{
for(int j = 0; j < K0; j++)
{
for(int ii = 0; ii < N1; ii++)
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(i * K0 * N1 * K1 + j * N1 * K1 + ii * K1 + jj) =
b_k_n((i * N1 + ii) * K + (j * K1 + jj));
}
}
}
}
#else
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
}
}
#endif
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
......@@ -178,12 +213,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n(j + k * 2, i);
int i4x2 = b_k_n_permute(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// for(int k = 1; k <= 4; k++)
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
......@@ -218,30 +253,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
#if 0
ck::pk_i4_t i4s[4];
i4s[0] = 0xa8;
i4s[1] = 0xec;
i4s[2] = 0xb9;
i4s[3] = 0xfd;
ck::vector_type<ck::half_t, 8> result;
result.template AsType<ck::half4_t>()(ck::Number<0>{}) = ck::pki4_to_half4(ck::bit_cast<int>(i4s));
result.template AsType<ck::half4_t>()(ck::Number<1>{}) = ck::pki4_to_half4(ck::bit_cast<int>(i4s) >> 8);
printf("%f %f %f %f %f %f %f %f\n",
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<0>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<1>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<2>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<3>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<4>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<5>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<6>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<7>{}])
);
#endif
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
DeviceMem workspace;
......
......@@ -387,7 +387,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else
{
#if 1
// B Tile Permute
#if 0
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
......@@ -396,18 +397,22 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
#else
const index_t N1 = NPerBlock;
const index_t N0 = N / N1;
const auto b_grid_desc_n0_bk0_n1_bk1 =
make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));
constexpr index_t N1 = NPerBlock;
constexpr index_t BK01 = KPerBlock / BK1Value;
const index_t BK00 = BK0 / BK01;
const index_t N0 = N / N1;
const auto b_grid_desc_n0_bk00_n1_bk01_bk1 =
make_naive_tensor_descriptor_packed(make_tuple(N0, BK00, N1, BK01, BK1Value));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n0_bk0_n1_bk1,
make_tuple(make_pass_through_transform(BK0),
b_grid_desc_n0_bk00_n1_bk01_bk1,
make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
make_merge_transform(make_tuple(N0, N1)),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
return b_grid_desc_bk0_n_bk1;
......@@ -614,7 +619,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
#if 1
#if 0
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
#else
const int k0_offset = karg.KRead * NPerBlock;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment