Commit bf545630 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed

parent f83a2f38
...@@ -639,6 +639,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -639,6 +639,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
index_t GetKPerBlock() override { return KPerBlock; } index_t GetKPerBlock() override { return KPerBlock; }
bool GetPermuteA() override { return PermuteA; }
bool GetPermuteB() override { return PermuteB; } bool GetPermuteB() override { return PermuteB; }
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
......
...@@ -22,18 +22,22 @@ __host__ __device__ inline half4_t pki4_to_half4(int q) ...@@ -22,18 +22,22 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
const int HI = 0x00f000f0; const int HI = 0x00f000f0;
const int EX = 0x64006400; const int EX = 0x64006400;
// Extract the two int4 at low bit and create two fp16 number.
int lo = amd_assembly_and_or_b32(q, LO, EX); int lo = amd_assembly_and_or_b32(q, LO, EX);
// Extract the two int4 at hight bit and create two fp16 number.
int hi = amd_assembly_and_or_b32(q, HI, EX); int hi = amd_assembly_and_or_b32(q, HI, EX);
const int SUB = 0xE408E408; //-8 const int SUB = 0xE408E408; // half2 {-1032, -1032}
const int MUL = 0x2c002c00; // 1/16 const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
const int ADD = 0xd480d480; //-79 const int ADD = 0xd480d480; // half2 {-72, -72}
vector_type<half_t, 4> res; vector_type<half_t, 4> res;
// for two fp16 from lowbit, subtract 1032 to get correct fp16 value
res.template AsType<half2_t>()(Number<0>{}) = res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB)); amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
// for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16( res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD)); bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment