Commit be98313d authored by Jing Zhang's avatar Jing Zhang
Browse files

add b tile permute

parent e053e947
...@@ -170,6 +170,41 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -170,6 +170,41 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
#if 1
int NPerBlock = 32;
int KPerBlock = 128;
int N1 = NPerBlock;
int K1 = KPerBlock;
int N0 = N / N1;
int K0 = K / K1;
for(int i = 0; i < N0; i++)
{
for(int j = 0; j < K0; j++)
{
for(int ii = 0; ii < N1; ii++)
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(i * K0 * N1 * K1 + j * N1 * K1 + ii * K1 + jj) =
b_k_n((i * N1 + ii) * K + (j * K1 + jj));
}
}
}
}
#else
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
}
}
#endif
// vector pk_i4x4 permute
for(int i = 0; i < N; i++) for(int i = 0; i < N; i++)
{ {
for(int j = 0; j < K; j += 8) for(int j = 0; j < K; j += 8)
...@@ -178,12 +213,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -178,12 +213,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
for(int k = 0; k < 4; k++) for(int k = 0; k < 4; k++)
{ {
int i4x2 = b_k_n(j + k * 2, i); int i4x2 = b_k_n_permute(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf; input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf; input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
} }
// for(int k = 1; k <= 4; k++) // permute 01234567->20643175
{ {
int hi = input[2]; int hi = input[2];
int lo = input[0]; int lo = input[0];
...@@ -218,30 +253,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -218,30 +253,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
} }
} }
#if 0
ck::pk_i4_t i4s[4];
i4s[0] = 0xa8;
i4s[1] = 0xec;
i4s[2] = 0xb9;
i4s[3] = 0xfd;
ck::vector_type<ck::half_t, 8> result;
result.template AsType<ck::half4_t>()(ck::Number<0>{}) = ck::pki4_to_half4(ck::bit_cast<int>(i4s));
result.template AsType<ck::half4_t>()(ck::Number<1>{}) = ck::pki4_to_half4(ck::bit_cast<int>(i4s) >> 8);
printf("%f %f %f %f %f %f %f %f\n",
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<0>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<1>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<2>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<3>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<4>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<5>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<6>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<7>{}])
);
#endif
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
DeviceMem workspace; DeviceMem workspace;
......
...@@ -387,7 +387,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -387,7 +387,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
else else
{ {
#if 1 // B Tile Permute
#if 0
// not pad N or K // not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw, b_grid_desc_nraw_kraw,
...@@ -396,18 +397,22 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -396,18 +397,22 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
#else #else
const index_t N1 = NPerBlock; constexpr index_t N1 = NPerBlock;
const index_t N0 = N / N1; constexpr index_t BK01 = KPerBlock / BK1Value;
const auto b_grid_desc_n0_bk0_n1_bk1 = const index_t BK00 = BK0 / BK01;
make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value)); const index_t N0 = N / N1;
const auto b_grid_desc_n0_bk00_n1_bk01_bk1 =
make_naive_tensor_descriptor_packed(make_tuple(N0, BK00, N1, BK01, BK1Value));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n0_bk0_n1_bk1, b_grid_desc_n0_bk00_n1_bk01_bk1,
make_tuple(make_pass_through_transform(BK0), make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
make_merge_transform(make_tuple(N0, N1)), make_merge_transform(make_tuple(N0, N1)),
make_pass_through_transform(BK1Value)), make_pass_through_transform(BK1Value)),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}), make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif #endif
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
...@@ -614,7 +619,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -614,7 +619,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
#if 1 #if 0
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
#else #else
const int k0_offset = karg.KRead * NPerBlock; const int k0_offset = karg.KRead * NPerBlock;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment