Commit a6ccd2ec authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed int4 to bhalf_t conversion

parent d642ce41
......@@ -65,7 +65,7 @@ using DeviceGemmV2Instance =
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#endif
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, CDataType, CDataType, false, PermuteB>;
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, CDataType, CDataType, false, PermuteB>;
// clang-format on
......@@ -146,7 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
}
......
......@@ -51,10 +51,11 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(pk_i4x2_t i4s)
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
{
uint32_t q = bit_cast<uint16_t>(i4s);
uint32_t i8s = (q & 0xf) | (q & 0xf0 << 4) | (q & 0xf00 << 8) | (q & 0xf000 << 12);
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
//uint32_t i8s = q & 0xf0f0f0f;
static constexpr uint32_t fp32_base = 0x4B000000;
float fp32_intermediates[4];
......@@ -72,8 +73,8 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(pk_i4x2_t i4s)
fp32_intermediates[3] -= 8388616.f;
vector_type<bhalf_t, 4> res;
res.template AsType<bhalf2_t>()(Number<0>{}) = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
res.template AsType<bhalf2_t>()(Number<1>{}) = __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[2], 0x7632);
res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(__byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632));
res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(__byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632));
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
......@@ -133,8 +134,9 @@ struct PassThroughPack8
#if 1
vector_type<bhalf_t, 8> result;
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x));
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
......
......@@ -40,11 +40,11 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......
......@@ -177,6 +177,11 @@ void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
DeviceGemmV2<Row, Col, Row, F16, I4, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, I4, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -827,6 +832,16 @@ struct DeviceOperationInstanceFactory<
}
}
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
}
}
return op_ptrs;
}
};
......
......@@ -98,6 +98,7 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
......
......@@ -175,51 +175,54 @@ bool profile_gemm_universal_impl(int do_verification,
}
}
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
if(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
{
for(int j = 0; j < K; j += 8)
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
}
for(int j = 0; j < K; j += 8)
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
}
}
}
}
......
......@@ -28,6 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
F16_I4_F16, // 8
BF16_I4_BF16, // 9
};
#define OP_NAME "gemm_universal"
......@@ -40,7 +41,7 @@ int profile_gemm_universal(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: f16@i4)\n");
"comp f8; 8: f16@i4; 9: bf16@i4\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
......@@ -193,6 +194,10 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_I4_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(BF16{}, I4{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment