Commit 80b46cae authored by mtgu0705

Modify mfma_16x16x16 -> mfma_32x32x8 for int8_dequant to match the kernel of gemm_xdl_fp16_pk_i4_v3
parent 1e741191
@@ -56,9 +56,12 @@ using CDEElementOp = PassThrough;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
-// static constexpr ck::index_t Scale_Block_M = 128;
-static constexpr ck::index_t Scale_Block_N = 128;
-static constexpr ck::index_t Scale_Block_K = 128;
+static constexpr ck::index_t Scale_Block_N = 1;
+static constexpr ck::index_t Scale_Block_K = 64;
+static constexpr ck::index_t KPerBlock = 64;
+static constexpr bool PermuteB = true;
 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
 // clang-format off
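Note on the new scale blocking: Scale_Block_N = 1 with Scale_Block_K = 64 means one scale per output column per 64-element slab of K, instead of one scale per 128x128 tile of B. A worked shape example (problem sizes are illustrative assumptions; the indexing rule is the one the host reference uses below):

    // Illustrative only: K and N are assumed sizes, not from the commit.
    constexpr int K = 512, N = 4096;
    constexpr int Scale_Block_N = 1, Scale_Block_K = 64;
    constexpr int scale_rows = K / Scale_Block_K; // 8 scale rows along K
    constexpr int scale_cols = N / Scale_Block_N; // one scale per output column
    static_assert(scale_rows == 8 && scale_cols == 4096, "B1 scale tensor is 8 x 4096 here");
    // element (k, n) of B is scaled by b1_k_n(k / Scale_Block_K, n / Scale_Block_N)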
@@ -66,20 +69,19 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
     A0DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
     AElementOp, BElementOp, CDEElementOp, GemmSpec,
     256, Scale_Block_N, Scale_Block_K,
-    128, 128, 128,
+    128, 128, KPerBlock,
     // 16, 16,
     8, 8,
-    16, 16,
-    4, 4,
-    // S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
-    // S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
-    S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
-    S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+    32, 32,
+    2, 2,
+    S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+    S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+    // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+    // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
     // ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
-    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
-[[maybe_unused]] static int KPerBlock = 128; // needs to be aligned to the KPerBlock set in the device kernel above.
+    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, false, PermuteB>;
 // clang-format on
 int main(int argc, char* argv[])
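Quick sanity check on the MFMA switch (my arithmetic, not from the commit): both the old 16x16x16 and the new 32x32x8 configuration cover the same 128x128 block tile, assuming the usual CK decomposition of BlockSize = 256 into 2 x 2 waves of 64 lanes:

    constexpr int MPerBlock = 128, NPerBlock = 128;
    constexpr int MWaves = 2, NWaves = 2; // assumption: 256 threads = 2 x 2 waves of 64 lanes
    // old: mfma_16x16x16 with 4 x 4 XDL ops per wave
    static_assert(16 * 4 * MWaves == MPerBlock && 16 * 4 * NWaves == NPerBlock, "old tiling");
    // new: mfma_32x32x8 with 2 x 2 XDL ops per wave
    static_assert(32 * 2 * MWaves == MPerBlock && 32 * 2 * NWaves == NPerBlock, "new tiling");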
@@ -210,7 +212,7 @@ int main(int argc, char* argv[])
             {
                 for(int jj = 0; jj < K1; jj++)
                 {
-                    b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
+                    b0_k_n_permute(j * N * K1 + i * K1 + jj) = b0_k_n(i * K + (j * K1 + jj));
                 }
             }
         }
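Apart from the b0 rename, this loop is the PermuteB layout transform: each K-block of K1 elements for a column n stays contiguous, and blocks are interleaved across N. A standalone sketch of the same mapping (hypothetical helper with a flat std::vector in place of CK's Tensor):

    #include <vector>

    // dst[(k / K1) * N * K1 + n * K1 + (k % K1)] = src[n * K + k], i.e. the
    // assignment above with j = k / K1 and jj = k % K1.
    std::vector<int> permute_b(const std::vector<int>& src, int N, int K, int K1)
    {
        std::vector<int> dst(src.size());
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                dst[(k / K1) * N * K1 + n * K1 + (k % K1)] = src[n * K + k];
        return dst;
    }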
@@ -221,7 +223,7 @@ int main(int argc, char* argv[])
         {
             for(int j = 0; j < K; j++)
             {
-                b_k_n_permute(i * K + j) = b_k_n(i * K + j);
+                b0_k_n_permute(i * K + j) = b0_k_n(i * K + j);
             }
         }
     }
@@ -235,7 +237,7 @@ int main(int argc, char* argv[])
             for(int k = 0; k < 4; k++)
             {
-                int i4x2 = b_k_n_permute(j + k * 2, i);
+                int i4x2 = b0_k_n_permute(j + k * 2, i);
                 input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
                 input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
             }
@@ -246,7 +248,7 @@ int main(int argc, char* argv[])
                 int lo = input[0];
                 int i4x2 = (hi << 4) | lo;
-                b_k_n_permute(j + 0, i) = i4x2;
+                b0_k_n_permute(j + 0, i) = i4x2;
             }
             {
@@ -254,7 +256,7 @@ int main(int argc, char* argv[])
                 int lo = input[4];
                 int i4x2 = (hi << 4) | lo;
-                b_k_n_permute(j + 2, i) = i4x2;
+                b0_k_n_permute(j + 2, i) = i4x2;
             }
             {
@@ -262,7 +264,7 @@ int main(int argc, char* argv[])
                 int lo = input[1];
                 int i4x2 = (hi << 4) | lo;
-                b_k_n_permute(j + 4, i) = i4x2;
+                b0_k_n_permute(j + 4, i) = i4x2;
             }
             {
@@ -270,7 +272,7 @@ int main(int argc, char* argv[])
                 int lo = input[5];
                 int i4x2 = (hi << 4) | lo;
-                b_k_n_permute(j + 6, i) = i4x2;
+                b0_k_n_permute(j + 6, i) = i4x2;
             }
         }
     }
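The four single-line blocks above together perform one nibble swizzle: eight int4 values are unpacked from four bytes and repacked in the pair order 2,0 / 6,4 / 3,1 / 7,5. The hi = input[...] lines are collapsed in this view, so the hi indices below are inferred from gemm_xdl_fp16_pk_i4_v3, which this commit is matching:

    #include <cstdint>

    // bytes[k] holds nibbles 2k (high) and 2k+1 (low), matching the unpack loop above.
    void swizzle_8_nibbles(std::uint8_t bytes[4])
    {
        int in[8];
        for(int k = 0; k < 4; ++k)
        {
            in[k * 2 + 0] = (bytes[k] >> 4) & 0xf;
            in[k * 2 + 1] = bytes[k] & 0xf;
        }
        const int hi_idx[4] = {2, 6, 3, 7}; // inferred: these lines are collapsed in the diff
        const int lo_idx[4] = {0, 4, 1, 5}; // visible in the four blocks above
        for(int k = 0; k < 4; ++k)
            bytes[k] = static_cast<std::uint8_t>((in[hi_idx[k]] << 4) | in[lo_idx[k]]);
    }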
@@ -341,11 +343,21 @@ int main(int argc, char* argv[])
             }
         }
+        float v_b = 0;
         for(int n = 0; n < N; n++)
         {
             for(int k = 0; k < K; k++)
             {
-                b_k_n(k, n) = ck::type_convert<float>(quant_b0_k_n(k, n)) *
+                pk_i4_t i4x2 = b0_k_n(k, n);
+                int8_t i4 = 0;
+                if(k % 2 == 1)
+                    i4 = (i4x2 >> 0) & 0xf;
+                else
+                    i4 = (i4x2 >> 4) & 0xf;
+                i4 = i4 - 8;
+                v_b = type_convert<float>(i4);
+                b_k_n(k, n) = ck::type_convert<float>(v_b) *
                               ck::type_convert<float>(b1_k_n(k / Scale_Block_K, n / Scale_Block_N));
             }
         }
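The host reference now dequantizes B explicitly instead of reading a pre-dequantized quant_b0_k_n: select the nibble by the parity of k, subtract the +8 offset (the int4 values are stored offset-binary), then apply the block scale. The same logic as a self-contained helper (the name dequant_i4 is mine, not the repo's):

    #include <cstdint>

    float dequant_i4(std::uint8_t i4x2, int k, float scale)
    {
        // even k -> high nibble, odd k -> low nibble, as in the loop above
        int i4 = (k % 2 == 1) ? (i4x2 & 0xf) : ((i4x2 >> 4) & 0xf);
        return static_cast<float>(i4 - 8) * scale; // remove the +8 offset, apply the block scale
    }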
@@ -73,8 +73,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
     128, 128, 64,
     // 16, 16,
     8, 8,
-    16, 16,
-    4, 4,
+    32, 32,
+    2, 2,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
     // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
@@ -71,6 +71,8 @@ template <typename ALayout,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
           typename ComputeTypeA = CDataType,
           typename ComputeTypeB = ComputeTypeA,
+          bool PermuteA = false,
+          bool PermuteB = false,
           typename LDSTypeA = ComputeTypeA,
           typename LDSTypeB = ComputeTypeB>
 struct DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
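PermuteA and PermuteB are appended with defaults so that every existing instantiation keeps compiling unchanged. A minimal compilable analogue of the pattern (not the CK template itself):

    // Appending defaulted non-type parameters preserves old instantiations
    // while letting new callers opt in explicitly.
    template <typename ComputeA, typename ComputeB = ComputeA,
              bool PermuteA = false, bool PermuteB = false>
    struct DeviceOpLike
    {
        static constexpr bool permute_b = PermuteB;
    };

    static_assert(!DeviceOpLike<float>::permute_b, "old callers see the default");
    static_assert(DeviceOpLike<float, float, false, true>::permute_b, "new callers opt in");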
@@ -158,6 +158,20 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+    static constexpr index_t APackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
+    static constexpr index_t BPackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
     __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
     {
         return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
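APackedSize and BPackedSize let the gridwise kernel convert between logical element counts and storage units when A or B is pk_i4_t, which packs two values per byte. A sketch of the arithmetic this enables (my reading of the semantics, not code from the hunk):

    constexpr int BPackedSizeSketch = 2;  // pk_i4_t: two int4 values per byte
    constexpr int KPerBlockSketch   = 64; // logical K elements per block
    static_assert(KPerBlockSketch / BPackedSizeSketch == 32,
                  "64 int4 values along K occupy 32 storage bytes");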
@@ -1388,6 +1402,12 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
         auto b_thread_offset =
             get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl;
+        // __syncthreads();
+        // if(blockIdx.x==0)
+        // {
+        //     printf("ThreadIdx: %d, b_thread_offset: %d\n", get_thread_local_1d_id(), b_thread_offset);
+        // }
         auto b_scale_thread_copy =
             ThreadwiseTensorSliceTransfer_v2<BScaleType,
                                              BScaleType,
@@ -74,6 +74,17 @@ struct ReferenceGemm : public device::BaseOperator
                 {
                     ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
                 }
+                else if constexpr(is_same_v<ADataType, pk_i4_t>)
+                {
+                    pk_i4_t i4x2 = arg.a_m_k_(m, k);
+                    int8_t i4 = 0;
+                    if(k % 2 == 1)
+                        i4 = (i4x2 >> 0) & 0xf;
+                    else
+                        i4 = (i4x2 >> 4) & 0xf;
+                    i4 = i4 - 8;
+                    v_a = type_convert<ComputeTypeA>(i4);
+                }
                 else
                 {
                     arg.a_element_op_(v_a, arg.a_m_k_(m, k));
@@ -84,6 +95,17 @@ struct ReferenceGemm : public device::BaseOperator
                 {
                     ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
                 }
+                else if constexpr(is_same_v<BDataType, pk_i4_t>)
+                {
+                    pk_i4_t i4x2 = arg.b_k_n_(k, n);
+                    int8_t i4 = 0;
+                    if(k % 2 == 1)
+                        i4 = (i4x2 >> 0) & 0xf;
+                    else
+                        i4 = (i4x2 >> 4) & 0xf;
+                    i4 = i4 - 8;
+                    v_b = type_convert<ComputeTypeB>(i4);
+                }
                 else
                 {
                     arg.b_element_op_(v_b, arg.b_k_n_(k, n));
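Both reference branches rely on the same packing convention: even k in the high nibble, odd k in the low nibble, values offset by +8. A quick round-trip check of that convention (illustrative test, not part of the commit):

    #include <cassert>
    #include <cstdint>

    int main()
    {
        std::int8_t even = -3, odd = 5; // logical signed int4 values
        auto i4x2 = static_cast<std::uint8_t>(((even + 8) << 4) | (odd + 8)); // pack with +8 offset
        assert(((i4x2 >> 4) & 0xf) - 8 == even); // k % 2 == 0 -> high nibble
        assert((i4x2 & 0xf) - 8 == odd);         // k % 2 == 1 -> low nibble
        return 0;
    }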