Commit a6ccd2ec authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed int4 to bhalf_t conversion

parent d642ce41
...@@ -65,7 +65,7 @@ using DeviceGemmV2Instance = ...@@ -65,7 +65,7 @@ using DeviceGemmV2Instance =
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4, 1, 1, S<1, 16, 1, 8>, 4,
#endif #endif
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, CDataType, CDataType, false, PermuteB>; ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, CDataType, CDataType, false, PermuteB>;
// clang-format on // clang-format on
...@@ -146,7 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -146,7 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1}); b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
} }
......
...@@ -51,10 +51,11 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q) ...@@ -51,10 +51,11 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB)); return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
} }
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(pk_i4x2_t i4s) __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
{ {
uint32_t q = bit_cast<uint16_t>(i4s); uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
uint32_t i8s = (q & 0xf) | (q & 0xf0 << 4) | (q & 0xf00 << 8) | (q & 0xf000 << 12); //uint32_t i8s = q & 0xf0f0f0f;
static constexpr uint32_t fp32_base = 0x4B000000; static constexpr uint32_t fp32_base = 0x4B000000;
float fp32_intermediates[4]; float fp32_intermediates[4];
...@@ -72,8 +73,8 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(pk_i4x2_t i4s) ...@@ -72,8 +73,8 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(pk_i4x2_t i4s)
fp32_intermediates[3] -= 8388616.f; fp32_intermediates[3] -= 8388616.f;
vector_type<bhalf_t, 4> res; vector_type<bhalf_t, 4> res;
res.template AsType<bhalf2_t>()(Number<0>{}) = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(__byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632));
res.template AsType<bhalf2_t>()(Number<1>{}) = __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[2], 0x7632); res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(__byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632));
return res.template AsType<bhalf4_t>()[Number<0>{}]; return res.template AsType<bhalf4_t>()[Number<0>{}];
} }
...@@ -133,8 +134,9 @@ struct PassThroughPack8 ...@@ -133,8 +134,9 @@ struct PassThroughPack8
#if 1 #if 1
vector_type<bhalf_t, 8> result; vector_type<bhalf_t, 8> result;
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x)); result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16); result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x));
y = result.template AsType<bhalf8_t>()[Number<0>{}]; y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else #else
......
...@@ -177,6 +177,11 @@ void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( ...@@ -177,6 +177,11 @@ void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
DeviceGemmV2<Row, Col, Row, F16, I4, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, I4, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, I4, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -827,6 +832,16 @@ struct DeviceOperationInstanceFactory< ...@@ -827,6 +832,16 @@ struct DeviceOperationInstanceFactory<
} }
} }
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
}
}
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -98,6 +98,7 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES ...@@ -98,6 +98,7 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
......
...@@ -175,6 +175,8 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -175,6 +175,8 @@ bool profile_gemm_universal_impl(int do_verification,
} }
} }
if(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
{
// vector pk_i4x4 permute // vector pk_i4x4 permute
for(int i = 0; i < N; i++) for(int i = 0; i < N; i++)
{ {
...@@ -224,6 +226,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -224,6 +226,7 @@ bool profile_gemm_universal_impl(int do_verification,
} }
} }
} }
}
else else
{ {
for(int i = 0; i < N; i++) for(int i = 0; i < N; i++)
......
...@@ -28,6 +28,7 @@ enum struct GemmDataType ...@@ -28,6 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8, // 6 F16_F16_F16_F8, // 6
F8_F8_BF16, // 7 F8_F8_BF16, // 7
F16_I4_F16, // 8 F16_I4_F16, // 8
BF16_I4_BF16, // 9
}; };
#define OP_NAME "gemm_universal" #define OP_NAME "gemm_universal"
...@@ -40,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[]) ...@@ -40,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, " "f16->f8; 7: f8->bf16, "
"comp f8; 8: f16@i4)\n"); "comp f8; 8: f16@i4; 9: bf16@i4\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
...@@ -193,6 +194,10 @@ int profile_gemm_universal(int argc, char* argv[]) ...@@ -193,6 +194,10 @@ int profile_gemm_universal(int argc, char* argv[])
{ {
return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
} }
else if(data_type == GemmDataType::BF16_I4_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(BF16{}, I4{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
else else
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment