Commit a7f24bfc authored by mtgu0705's avatar mtgu0705
Browse files

Add unit test for int4 weight only

parent 8b83b087
......@@ -30,6 +30,7 @@ add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
add_example_executable(example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp)
add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
......
......@@ -22,6 +22,7 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
static constexpr ck::index_t Scale_Block_N = 1;
......@@ -45,7 +46,7 @@ using DeviceGemmV2Instance =
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, false, PermuteB>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>;
// clang-format on
......@@ -82,21 +83,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is zero, return a default packed stride
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return col;
return static_cast<std::size_t>(col);
}
else
{
return row;
return static_cast<std::size_t>(row);
}
}
else
return stride;
return static_cast<std::size_t>(stride);
};
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
......@@ -164,11 +165,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
printf("b_k_n element space size: %zu, b_k_n device size: %lu, BDataType size: %lu\n",
b_k_n_permute.mDesc.GetElementSpaceSize(),
sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize(),
sizeof(BDataType));
// weight permute
if constexpr(PermuteB)
{
......@@ -199,7 +195,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
// vector pk_i4x4 permute
#if 1
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
......@@ -208,7 +203,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i);
int i4x2 = b_k_n_permute(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
......@@ -247,7 +242,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}
}
#endif
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
......@@ -287,9 +281,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
return true;
}
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
printf("workspace_size: %zu\n", workspace_size);
bool pass = true;
if(config.do_verification)
{
......@@ -300,12 +291,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b_k_n(k, n);
ck::pk_i4_t i4x2 = b_k_n(k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = (i4x2.data >> 4) & 0xf;
i4 = i4 - 8;
v_b = ck::type_convert<float>(i4);
......@@ -331,65 +322,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
#if 0
std::cout << "a_m_k: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < K; j++)
{
std::cout << ck::type_convert<float>(a_m_k(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "b_k_n: " << std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
ck::pk_i4_t i4x2 = b_k_n(j, i);
int8_t i4 = 0;
if( j % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
std::cout << ck::type_convert<float>(i4) << ",";
}
std::cout << std::endl;
}
std::cout<<"scale_b1_k_n: "<<std::endl;
for(int i = 0; i < N; i++)
{
for(int j = 0; j < (K + Scale_Block_K - 1) / Scale_Block_K; j++)
{
std::cout << ck::type_convert<float>(b1_k_n(j,i)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
#endif
}
if(config.time_kernel)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment