Commit 9de3a085 authored by Jing Zhang's avatar Jing Zhang
Browse files

format

parent a6ccd2ec
...@@ -69,7 +69,7 @@ using DeviceGemmV2Instance = ...@@ -69,7 +69,7 @@ using DeviceGemmV2Instance =
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, CDataType,
AccDataType, AccDataType,
......
...@@ -58,7 +58,7 @@ using DeviceGemmV2Instance = ...@@ -58,7 +58,7 @@ using DeviceGemmV2Instance =
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, CDataType,
AccDataType, AccDataType,
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace ck { namespace ck {
//https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__host__ __device__ inline half4_t pki4_to_half4(int q) __host__ __device__ inline half4_t pki4_to_half4(int q)
{ {
const int LO = 0x000f000f; const int LO = 0x000f000f;
...@@ -54,7 +54,7 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q) ...@@ -54,7 +54,7 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q) __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
{ {
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
//uint32_t i8s = q & 0xf0f0f0f; // uint32_t i8s = q & 0xf0f0f0f;
static constexpr uint32_t fp32_base = 0x4B000000; static constexpr uint32_t fp32_base = 0x4B000000;
...@@ -73,8 +73,10 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q) ...@@ -73,8 +73,10 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
fp32_intermediates[3] -= 8388616.f; fp32_intermediates[3] -= 8388616.f;
vector_type<bhalf_t, 4> res; vector_type<bhalf_t, 4> res;
res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(__byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632)); res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(
res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(__byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632)); __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632));
res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(
__byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632));
return res.template AsType<bhalf4_t>()[Number<0>{}]; return res.template AsType<bhalf4_t>()[Number<0>{}];
} }
...@@ -94,7 +96,6 @@ __host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q) ...@@ -94,7 +96,6 @@ __host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
return res.template AsType<bhalf2_t>()[Number<0>{}]; return res.template AsType<bhalf2_t>()[Number<0>{}];
} }
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
...@@ -137,7 +138,6 @@ struct PassThroughPack8 ...@@ -137,7 +138,6 @@ struct PassThroughPack8
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16); result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x)); result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x));
y = result.template AsType<bhalf8_t>()[Number<0>{}]; y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else #else
vector_type<bhalf_t, 8> dst; vector_type<bhalf_t, 8> dst;
......
...@@ -838,7 +838,8 @@ struct DeviceOperationInstanceFactory< ...@@ -838,7 +838,8 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
{ {
add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs); add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment