Commit 80b46cae authored by mtgu0705's avatar mtgu0705
Browse files

Modify the mfma_16x16x16->mfma_32x32x8 for int8_dequant to match the kernel of...

Modify the mfma_16x16x16->mfma_32x32x8 for int8_dequant to match the kernel of gemm_xdl_fp16_pk_i4_v3
parent 1e741191
......@@ -56,9 +56,12 @@ using CDEElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr ck::index_t Scale_Block_M = 128;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t Scale_Block_N = 1;
static constexpr ck::index_t Scale_Block_K = 64;
static constexpr ck::index_t KPerBlock = 64;
static constexpr bool PermuteB = true;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
// clang-format off
......@@ -66,20 +69,19 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
A0DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_N, Scale_Block_K,
128, 128, 128,
128, 128, KPerBlock,
// 16, 16,
8, 8,
16, 16,
4, 4,
// S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
// S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
32, 32,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
// ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, false, PermuteB>;
// Host-side copy of the K-tile size; must stay aligned to the KPerBlock used by
// the device kernel instance above (64, per the tuning parameters passed to
// DeviceGemmMultiD_BScale_Xdl_CShuffle_V3). constexpr so it can be checked at
// compile time; [[maybe_unused]] because only some code paths read it.
[[maybe_unused]] static constexpr int KPerBlock = 64; // keep in sync with the device kernel above.
// clang-format on
int main(int argc, char* argv[])
......@@ -210,7 +212,7 @@ int main(int argc, char* argv[])
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
b0_k_n_permute(j * N * K1 + i * K1 + jj) = b0_k_n(i * K + (j * K1 + jj));
}
}
}
......@@ -221,7 +223,7 @@ int main(int argc, char* argv[])
{
for(int j = 0; j < K; j++)
{
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
b0_k_n_permute(i * K + j) = b0_k_n(i * K + j);
}
}
}
......@@ -235,7 +237,7 @@ int main(int argc, char* argv[])
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i);
int i4x2 = b0_k_n_permute(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
......@@ -246,7 +248,7 @@ int main(int argc, char* argv[])
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
b0_k_n_permute(j + 0, i) = i4x2;
}
{
......@@ -254,7 +256,7 @@ int main(int argc, char* argv[])
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
b0_k_n_permute(j + 2, i) = i4x2;
}
{
......@@ -262,7 +264,7 @@ int main(int argc, char* argv[])
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
b0_k_n_permute(j + 4, i) = i4x2;
}
{
......@@ -270,7 +272,7 @@ int main(int argc, char* argv[])
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
b0_k_n_permute(j + 6, i) = i4x2;
}
}
}
......@@ -341,11 +343,21 @@ int main(int argc, char* argv[])
}
}
float v_b = 0;
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
b_k_n(k, n) = ck::type_convert<float>(quant_b0_k_n(k, n)) *
pk_i4_t i4x2 = b0_k_n(k, n);
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
v_b = type_convert<float>(i4);
b_k_n(k, n) = ck::type_convert<float>(v_b) *
ck::type_convert<float>(b1_k_n(k / Scale_Block_K, n / Scale_Block_N));
}
}
......
......@@ -73,8 +73,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
128, 128, 64,
// 16, 16,
8, 8,
16, 16,
4, 4,
32, 32,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
......
......@@ -71,6 +71,8 @@ template <typename ALayout,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
......
......@@ -158,6 +158,20 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
return 2;
else
return 1;
}();
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
......@@ -1388,6 +1402,12 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
auto b_thread_offset =
get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl;
// __syncthreads();
// if(blockIdx.x==0)
// {
// printf("ThreadIdx: %d, b_thread_offset: %d\n", get_thread_local_1d_id(), b_thread_offset);
// }
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
......
......@@ -74,6 +74,17 @@ struct ReferenceGemm : public device::BaseOperator
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else if constexpr(is_same_v<ADataType, pk_i4_t>)
{
pk_i4_t i4x2 = arg.a_m_k_(m, k);
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
v_a = type_convert<ComputeTypeA>(i4);
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
......@@ -84,6 +95,17 @@ struct ReferenceGemm : public device::BaseOperator
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else if constexpr(is_same_v<BDataType, pk_i4_t>)
{
pk_i4_t i4x2 = arg.b_k_n_(k, n);
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
v_b = type_convert<ComputeTypeB>(i4);
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment