Commit e053e947 authored by Jing Zhang

weight permute

parent 82bb8dde
@@ -134,6 +134,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
@@ -169,8 +170,80 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
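// Host-side repack of the packed-int4 weights. Each pk_i4 byte of b_k_n holds
// two consecutive k values (even k in the hi nibble, odd k in the lo nibble).
// The loop below regroups every eight nibbles along k so that one output byte
// pairs elements (k, k+2), the next (k+4, k+6), then (k+1, k+3) and (k+5, k+7),
// presumably the interleaved order the device-side pki4_to_half4 expansion reads.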
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// for(int k = 1; k <= 4; k++)
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
}
}
}
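// Disabled sanity check: expands four hand-picked packed-int4 bytes through
// pki4_to_half4 and prints the eight recovered half values.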
#if 0
ck::pk_i4_t i4s[4];
i4s[0] = 0xa8;
i4s[1] = 0xec;
i4s[2] = 0xb9;
i4s[3] = 0xfd;
ck::vector_type<ck::half_t, 8> result;
result.template AsType<ck::half4_t>()(ck::Number<0>{}) = ck::pki4_to_half4(ck::bit_cast<int>(i4s));
result.template AsType<ck::half4_t>()(ck::Number<1>{}) = ck::pki4_to_half4(ck::bit_cast<int>(i4s) >> 8);
printf("%f %f %f %f %f %f %f %f\n",
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<0>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<1>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<2>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<3>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<4>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<5>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<6>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<7>{}])
);
#endif
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
DeviceMem workspace;
auto a_element_op = AElementOp{};
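For reference, the repacking above can be exercised in isolation. The following is a minimal standalone sketch (an illustration written for this note, not part of the commit): it packs eight nibbles hi-first as in b_k_n, applies the same (2,0) (6,4) (3,1) (7,5) regrouping as the permute loop, and checks that reading the permuted bytes back recovers every element.

#include <cstdint>
#include <cstdio>

int main()
{
    const uint8_t src[4] = {0x01, 0x23, 0x45, 0x67}; // nibbles 0..7, hi-first
    int in[8];
    for(int p = 0; p < 4; ++p)
    {
        in[2 * p + 0] = (src[p] >> 4) & 0xf; // even k -> hi nibble
        in[2 * p + 1] = (src[p] >> 0) & 0xf; // odd k  -> lo nibble
    }

    // Same regrouping as the permute loop: (hi, lo) = (2,0), (6,4), (3,1), (7,5).
    const uint8_t dst[4] = {uint8_t((in[2] << 4) | in[0]),
                            uint8_t((in[6] << 4) | in[4]),
                            uint8_t((in[3] << 4) | in[1]),
                            uint8_t((in[7] << 4) | in[5])};

    // Read the permuted bytes back: byte 0 holds k=0 (lo) and k=2 (hi),
    // byte 1 holds k=4/k=6, byte 2 holds k=1/k=3, byte 3 holds k=5/k=7.
    const int order[4][2] = {{0, 2}, {4, 6}, {1, 3}, {5, 7}};
    bool ok = true;
    for(int p = 0; p < 4; ++p)
    {
        ok &= ((dst[p] >> 0) & 0xf) == in[order[p][0]];
        ok &= ((dst[p] >> 4) & 0xf) == in[order[p][1]];
    }
    printf("permute round-trip: %s\n", ok ? "ok" : "MISMATCH");
}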
@@ -10,10 +10,8 @@
#include "ck/utility/amd_inline_asm.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
__device__ inline half4_t pki4_to_half4(int q)
__host__ __device__ inline half4_t pki4_to_half4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
@@ -40,7 +38,7 @@ __device__ inline half4_t pki4_to_half4(int q)
return res.template AsType<half4_t>()[Number<0>{}];
}
__device__ inline half2_t pki4_to_half2(pk_i4_t q)
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{
#if 0
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
@@ -58,7 +56,7 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
int x_l = (x_u8 & 0x0f);
int x_h = (x_u8 & 0xf0) << 12;
const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8
int lo = (x_l | x_h) | EX;
@@ -67,6 +65,9 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
#endif
}
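// Both conversions use the fp16 "magic number" trick: OR-ing a 4-bit value x
// into the mantissa of 1024.0 (bit pattern 0x6400) yields exactly the half
// value 1024 + x, and subtracting 1032.0 (0xE408 encodes -1032.0) then maps
// the unsigned nibble range [0, 15] onto the signed int4 range [-8, 7].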
namespace tensor_operation {
namespace element_wise {
struct PassThroughPack8
{
template <typename Y, typename X>
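The constants above can be checked without CK or a GPU. The sketch below (an illustration written for this note; half_bits_to_float is a hypothetical helper that decodes fp16 bit patterns manually) shows that 0x6400 | x reads back as 1024 + x and that adding 0xE408 (-1032.0) recovers the signed int4 value.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode a normal IEEE fp16 bit pattern to float; sufficient here because the
// magic-number trick only produces normal values around 1024.
static float half_bits_to_float(uint16_t h)
{
    const int   sign = (h >> 15) & 0x1;
    const int   expo = (h >> 10) & 0x1f;
    const int   mant = h & 0x3ff;
    const float mag  = (1.0f + mant / 1024.0f) * std::ldexp(1.0f, expo - 15);
    return sign ? -mag : mag;
}

int main()
{
    for(int x = 0; x < 16; ++x)
    {
        // OR the nibble into the mantissa of 1024.0 (0x6400): exactly 1024 + x.
        const uint16_t bits = uint16_t(0x6400 | x);
        // 0xE408 encodes -1032.0; adding it maps [0, 15] onto [-8, 7].
        const float dequant = half_bits_to_float(bits) + half_bits_to_float(0xE408);
        printf("nibble %2d -> %5.1f\n", x, dequant);
    }
}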
@@ -396,8 +396,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
#else
const index_t N0 = N / NPerBlock;
const index_t N1 = NPerBlock;
const index_t N0 = N / N1;
const auto b_grid_desc_n0_bk0_n1_bk1 =
make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));
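// The permuted B operand is viewed as a packed (N0, BK0, N1, BK1) descriptor:
// N is split into N0 = N / NPerBlock tiles of N1 = NPerBlock columns each.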
@@ -614,7 +614,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
#if 1
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
#else
const int k0_offset = karg.KRead * NPerBlock;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
#endif
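// BPackedSize compensates for element packing (two int4 values per pk_i4_t),
// so the per-split offset is counted in storage elements, not logical k indices.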
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
@@ -78,7 +78,7 @@ struct ReferenceGemm : public device::BaseOperator
{
pk_i4_t i4x2 = arg.a_m_k_(m, k);
int8_t i4 = 0;
if(k % 2 == 0)
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
@@ -99,7 +99,7 @@ struct ReferenceGemm : public device::BaseOperator
{
pk_i4_t i4x2 = arg.b_k_n_(k, n);
int8_t i4 = 0;
if(k % 2 == 0)
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
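// Nibble convention: even k values sit in the hi nibble and odd k values in
// the lo nibble, matching the host-side packing of the int4 operands.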