Commit e053e947 authored by Jing Zhang
Browse files

weight permute

parent 82bb8dde
...@@ -134,6 +134,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -134,6 +134,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method) switch(config.init_method)
{ {
...@@ -169,8 +170,80 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -169,8 +170,80 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// for(int k = 1; k <= 4; k++)
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
}
}
}
// Disabled debug snippet (compiled out by `#if 0`): fills four ck::pk_i4_t values with
// fixed bytes, feeds their bit pattern -- once as-is and once shifted right by 8 --
// through ck::pki4_to_half4, and prints the eight resulting ck::half_t lanes as floats.
#if 0
ck::pk_i4_t i4s[4];
i4s[0] = 0xa8;
i4s[1] = 0xec;
i4s[2] = 0xb9;
i4s[3] = 0xfd;
ck::vector_type<ck::half_t, 8> result;
result.template AsType<ck::half4_t>()(ck::Number<0>{}) = ck::pki4_to_half4(ck::bit_cast<int>(i4s));
result.template AsType<ck::half4_t>()(ck::Number<1>{}) = ck::pki4_to_half4(ck::bit_cast<int>(i4s) >> 8);
printf("%f %f %f %f %f %f %f %f\n",
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<0>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<1>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<2>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<3>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<4>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<5>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<6>{}]),
ck::type_convert<float>(result.template AsType<ck::half_t>()[ck::Number<7>{}])
);
#endif
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
DeviceMem workspace; DeviceMem workspace;
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
......
...@@ -10,10 +10,8 @@ ...@@ -10,10 +10,8 @@
#include "ck/utility/amd_inline_asm.hpp" #include "ck/utility/amd_inline_asm.hpp"
namespace ck { namespace ck {
namespace tensor_operation {
namespace element_wise {
__device__ inline half4_t pki4_to_half4(int q) __host__ __device__ inline half4_t pki4_to_half4(int q)
{ {
const int LO = 0x000f000f; const int LO = 0x000f000f;
const int HI = 0x00f000f0; const int HI = 0x00f000f0;
...@@ -40,7 +38,7 @@ __device__ inline half4_t pki4_to_half4(int q) ...@@ -40,7 +38,7 @@ __device__ inline half4_t pki4_to_half4(int q)
return res.template AsType<half4_t>()[Number<0>{}]; return res.template AsType<half4_t>()[Number<0>{}];
} }
__device__ inline half2_t pki4_to_half2(pk_i4_t q) __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{ {
#if 0 #if 0
uint8_t x_u8 = ck::bit_cast<uint8_t>(q); uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
...@@ -67,6 +65,9 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q) ...@@ -67,6 +65,9 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
#endif #endif
} }
namespace tensor_operation {
namespace element_wise {
struct PassThroughPack8 struct PassThroughPack8
{ {
template <typename Y, typename X> template <typename Y, typename X>
......
...@@ -396,8 +396,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -396,8 +396,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
#else #else
const index_t N0 = N / NPerBlock;
const index_t N1 = NPerBlock; const index_t N1 = NPerBlock;
const index_t N0 = N / N1;
const auto b_grid_desc_n0_bk0_n1_bk1 = const auto b_grid_desc_n0_bk0_n1_bk1 =
make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value)); make_naive_tensor_descriptor_packed(make_tuple(N0, BK0, N1, BK1Value));
...@@ -614,7 +614,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -614,7 +614,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
#if 1
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
#else
const int k0_offset = karg.KRead * NPerBlock;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
#endif
} }
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1)) if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
......
...@@ -78,7 +78,7 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -78,7 +78,7 @@ struct ReferenceGemm : public device::BaseOperator
{ {
pk_i4_t i4x2 = arg.a_m_k_(m, k); pk_i4_t i4x2 = arg.a_m_k_(m, k);
int8_t i4 = 0; int8_t i4 = 0;
if(k % 2 == 0) if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf; i4 = (i4x2 >> 0) & 0xf;
else else
i4 = (i4x2 >> 4) & 0xf; i4 = (i4x2 >> 4) & 0xf;
...@@ -99,7 +99,7 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -99,7 +99,7 @@ struct ReferenceGemm : public device::BaseOperator
{ {
pk_i4_t i4x2 = arg.b_k_n_(k, n); pk_i4_t i4x2 = arg.b_k_n_(k, n);
int8_t i4 = 0; int8_t i4 = 0;
if(k % 2 == 0) if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf; i4 = (i4x2 >> 0) & 0xf;
else else
i4 = (i4x2 >> 4) & 0xf; i4 = (i4x2 >> 4) & 0xf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment