Commit 9de3a085 authored by Jing Zhang's avatar Jing Zhang
Browse files

format

parent a6ccd2ec
......@@ -69,7 +69,7 @@ using DeviceGemmV2Instance =
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
......
......@@ -58,7 +58,7 @@ using DeviceGemmV2Instance =
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
......
......@@ -11,7 +11,7 @@
namespace ck {
//https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__host__ __device__ inline half4_t pki4_to_half4(int q)
{
const int LO = 0x000f000f;
......@@ -54,7 +54,7 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
{
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
//uint32_t i8s = q & 0xf0f0f0f;
// uint32_t i8s = q & 0xf0f0f0f;
static constexpr uint32_t fp32_base = 0x4B000000;
......@@ -73,8 +73,10 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
fp32_intermediates[3] -= 8388616.f;
vector_type<bhalf_t, 4> res;
res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(__byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632));
res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(__byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632));
res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(
__byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632));
res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(
__byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632));
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
......@@ -94,7 +96,6 @@ __host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
return res.template AsType<bhalf2_t>()[Number<0>{}];
}
namespace tensor_operation {
namespace element_wise {
......@@ -137,7 +138,6 @@ struct PassThroughPack8
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x));
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
vector_type<bhalf_t, 8> dst;
......
......@@ -838,7 +838,8 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment