Commit c13366af authored by Jing Zhang

add fast pki4 to half conversion

parent 24e18ae8
@@ -139,7 +139,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     {
     case 0:
         a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{0x11});
+        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{0x99});
        break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
...
@@ -7,11 +7,57 @@
 #include "ck/utility/math.hpp"
 #include "ck/utility/math_v2.hpp"
 #include "ck/utility/type_convert.hpp"
+#include "ck/utility/amd_inline_asm.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace element_wise {

+// Convert 4 packed int4 values to 4 half values without an integer-to-float cast:
+// OR-ing a nibble into the mantissa of the half constant 0x6400 (1024.0) yields
+// 1024 + x for the low nibbles and 1024 + 16 * x for the high nibbles, and two
+// packed f16 instructions then map both back to x - 8.
+__device__ inline half4_t pki4_to_half4(int q)
+{
+    const int LO = 0x000f000f;
+    const int HI = 0x00f000f0;
+    const int EX = 0x64006400;
+    // Guarantee that the `(a & b) | c` operations are LOP3s.
+    // int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+    // int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+    int lo = (q & LO) | EX; // each f16 lane holds 1024 + x
+    int hi = (q & HI) | EX; // each f16 lane holds 1024 + 16 * x
+
+    // We want signed int4 outputs, hence the -8 symmetric zero point is fused
+    // directly into SUB and ADD.
+    const int SUB = 0xE408E408; // half -1032.0: (1024 + x) - 1032        = x - 8
+    const int MUL = 0x2c002c00; // half 1/16
+    const int ADD = 0xd480d480; // half -72.0:   (1024 + 16 * x) / 16 - 72 = x - 8
+
+    vector_type<half_t, 4> res;
+    res.template AsType<half2_t>()(Number<0>{}) =
+        amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
+    res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
+        bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
+    return res.template AsType<half4_t>()[Number<0>{}];
+}
+
+struct PassThroughPack8
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
+    {
+        vector_type<half_t, 8> result;
+        // Convert the 8 packed int4 values in two passes of 4.
+        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
+        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
+        y = result.template AsType<half8_t>()[Number<0>{}];
+    }
+
+    static constexpr bool is_pack8_invocable = true;
+};
+
 struct PassThroughPack2
 {
     template <typename Y, typename X>
...
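Note: the SUB, MUL and ADD constants above are half-precision bit patterns (-1032.0, 1/16 and -72.0). A minimal host-side sketch in plain C++, independent of CK and HIP, checks the arithmetic they encode for every nibble value; the loop and printout are illustrative only and are not part of this commit.

#include <cassert>
#include <cstdio>

int main()
{
    for(int x = 0; x < 16; ++x)
    {
        // Low-nibble path: OR into 0x6400 gives 1024 + x, then v_pk_add_f16 with -1032.
        double lo = (1024.0 + x) - 1032.0;
        // High-nibble path: OR into 0x6400 gives 1024 + 16 * x, then v_pk_fma_f16 with 1/16 and -72.
        double hi = (1024.0 + 16.0 * x) * (1.0 / 16.0) - 72.0;
        assert(lo == x - 8 && hi == x - 8);
        std::printf("nibble %2d -> %5.1f %5.1f\n", x, lo, hi);
    }
    return 0;
}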
@@ -387,6 +387,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         else
         {
+#if 1
             // not pad N or K
             const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                 b_grid_desc_nraw_kraw,
@@ -394,6 +395,19 @@ struct GridwiseGemm_xdl_cshuffle_v3
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+#else
+            const index_t N0 = N / NPerBlock;
+            const index_t N1 = NPerBlock;
+
+            const auto b_grid_desc_n0_bk0_n1_bk1 =
+                make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_n0_bk0_n1_bk1,
+                make_tuple(make_pass_through_transform(BK0),
+                           make_merge_transform(make_tuple(N0, N1)),
+                           make_pass_through_transform(BK1Value)),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+#endif
             return b_grid_desc_bk0_n_bk1;
         }
...
@@ -1150,12 +1150,14 @@ struct ThreadwiseTensorSliceTransfer_v4
             // DstData)
             vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

-            constexpr index_t pack_size = PackedSize;
+            constexpr index_t pack_size = 8;
+
+            static_assert(SrcScalarPerVector % pack_size == 0, "");

             using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
-            using src_v_t = typename vector_type_maker_t<SrcData, 1>::type;
+            using src_v_t = typename vector_type_maker_t<SrcData, 4>::type;

             static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
-                ck::tensor_operation::element_wise::PassThroughPack2{}(
+                ck::tensor_operation::element_wise::PassThroughPack8{}(
                     dst_tmp_vector.template AsType<dst_v_t>()(i),
                     src_tmp_vector.template AsType<src_v_t>()[i]);
             });
...
@@ -11,6 +11,21 @@
 namespace ck {

+// Packed fp16 FMA on both 16-bit lanes of a half2_t via v_pk_fma_f16.
+inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
+{
+    half2_t d;
+    asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
+                 : "=v"(d)
+                 : "v"(a), "v"(b), "v"(c));
+    return d;
+}
+
+// Packed fp16 add on both 16-bit lanes of a half2_t via v_pk_add_f16.
+inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
+{
+    half2_t c;
+    asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+    return c;
+}
+
 // c0 += inner_product(a, b0)
 // c1 += inner_product(a, b1)
 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
...
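For reference, a portable fallback for these two wrappers would be ordinary vector arithmetic. The sketch below is not part of the commit and assumes CK's half2_t is a clang ext_vector_type with elementwise operators; unlike the inline asm above, it leaves instruction selection to the compiler instead of forcing v_pk_* instructions.

// Sketch only: relies on half2_t supporting elementwise operators.
inline __device__ half2_t pk_fma_f16_fallback(half2_t a, half2_t b, half2_t c)
{
    return a * b + c; // may or may not lower to v_pk_fma_f16
}

inline __device__ half2_t pk_add_f16_fallback(half2_t a, half2_t b)
{
    return a + b; // may or may not lower to v_pk_add_f16
}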
@@ -157,8 +157,8 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
     template <class FloatC>
     __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
     {
-        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
-            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
+        // reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
+        //     reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
     }
 };
...
@@ -1054,12 +1054,14 @@ using bf8x32_t = typename vector_type<bf8_t, 32>::type;
 using bf8x64_t = typename vector_type<bf8_t, 64>::type;

 // u8
 // i8
-using uint8x2_t  = typename vector_type<uint8_t, 2>::type;
-using uint8x4_t  = typename vector_type<uint8_t, 4>::type;
-using uint8x8_t  = typename vector_type<uint8_t, 8>::type;
-using uint8x16_t = typename vector_type<uint8_t, 16>::type;
-using uint8x32_t = typename vector_type<uint8_t, 32>::type;
-using uint8x64_t = typename vector_type<uint8_t, 64>::type;
+// using uint8x2_t  = typename vector_type<uint8_t, 2>::type;
+// using uint8x4_t  = typename vector_type<uint8_t, 4>::type;
+// using uint8x8_t  = typename vector_type<uint8_t, 8>::type;
+// using uint8x16_t = typename vector_type<uint8_t, 16>::type;
+// using uint8x32_t = typename vector_type<uint8_t, 32>::type;
+// using uint8x64_t = typename vector_type<uint8_t, 64>::type;
+using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
+using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;

 template <typename T>
 struct NumericLimits
...
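Putting the pieces together, a hypothetical device-side helper shows how the new pk_i4x4_t vector type and PassThroughPack8 are meant to be combined. The name dequant_pk_i4x4 is illustrative only; the surrounding kernel plumbing (the threadwise transfer shown earlier) is omitted.

// Sketch only: dequantize 8 packed int4 values into 8 halves via the new pack-8 path.
__device__ void dequant_pk_i4x4(const ck::pk_i4x4_t& packed, ck::half8_t& out)
{
    ck::tensor_operation::element_wise::PassThroughPack8{}(out, packed);
}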