Commit d1894bdb authored by aska-0096's avatar aska-0096
Browse files

tempsave

parent b2d5cf8a
...@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BLayout, BLayout,
CLayout, CLayout,
ADataType, ADataType,
BDataType, BDataType,
CDataType, CDataType,
AccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -35,16 +35,16 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle ...@@ -35,16 +35,16 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp, BElementOp,
CElementOp, CElementOp,
GemmDefault, GemmDefault,
2, // Prefetch stage 1, // Prefetch stage
128, // BlockSize 128, // BlockSize
128, // MPerBlock 64, // MPerBlock
64, // NPerBlock 128, // NPerBlock
64, // KPerBlock 64, // KPerBlock
8, // K1 8, // K1
16, // MPerWmma 16, // MPerWmma
16, // NPerWmma 16, // NPerWmma
4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
S<4, 32, 1>, S<4, 32, 1>,
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
......
...@@ -21,7 +21,7 @@ using QuantDataType = int8_t; ...@@ -21,7 +21,7 @@ using QuantDataType = int8_t;
using BDataType = uint8_t; using BDataType = uint8_t;
using ScaleDataType = ck::half_t; using ScaleDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using CShuffleDataType = ck::half_t; using CShuffleDataType = float;
using CDataType = ck::half_t; using CDataType = ck::half_t;
using ALayout = Row; using ALayout = Row;
......
...@@ -404,6 +404,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4> ...@@ -404,6 +404,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01); half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23); half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
// static constexpr ck::half_t fp16_subtract = -1152;
// Output.template AsType<ck::half_t>()(Number<0>{}) += fp16_subtract;
// Output.template AsType<ck::half_t>()(Number<1>{}) += fp16_subtract;
// Output.template AsType<ck::half_t>()(Number<2>{}) += fp16_subtract;
// Output.template AsType<ck::half_t>()(Number<3>{}) += fp16_subtract;
// inline assembly get very poor performance as no chance to global scheduling
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n" asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n"
: "=v"(half_2[0]) : "=v"(half_2[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment