Commit 66873f3a authored by Jing Zhang
Browse files

clean

parent 23bdf72a
...@@ -158,7 +158,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -158,7 +158,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute // weight permute
...@@ -190,57 +190,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -190,57 +190,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
} }
} }
#if 0
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
}
}
}
#endif
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
DeviceMem workspace; DeviceMem workspace;
......
...@@ -54,7 +54,7 @@ using DeviceGemmV2Instance = ...@@ -54,7 +54,7 @@ using DeviceGemmV2Instance =
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4, 1, 1, S<1, 16, 1, 8>, 4,
#endif #endif
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, CDataType, CDataType, false, PermuteB>; ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, false, PermuteB>;
// clang-format on // clang-format on
...@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -147,7 +147,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute // weight permute
...@@ -179,6 +179,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -179,6 +179,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
} }
} }
#if 1
// vector pk_i4x4 permute // vector pk_i4x4 permute
for(int i = 0; i < N; i++) for(int i = 0; i < N; i++)
{ {
...@@ -227,6 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -227,6 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
} }
} }
} }
#endif
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
......
...@@ -39,14 +39,12 @@ __host__ __device__ inline half4_t pki4_to_half4(int q) ...@@ -39,14 +39,12 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q) __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{ {
uint8_t x_u8 = ck::bit_cast<uint8_t>(q); uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
int x_l = (x_u8 & 0x0f);
int x_h = (x_u8 & 0xf0) << 12;
const int EX = 0x64006400; const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8 const int SUB = 0xE408E408; //-8
int lo = (x_l | x_h) | EX; int lo = i4s | EX;
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB)); return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
} }
...@@ -84,8 +82,8 @@ __host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q) ...@@ -84,8 +82,8 @@ __host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
{ {
uint8_t x_u8 = ck::bit_cast<uint8_t>(q); uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
float x_h = ((x_u8 & 0x0f) >> 0) - 8; float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_l = ((x_u8 & 0xf0) >> 4) - 8; float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;
vector_type<bhalf_t, 2> res; vector_type<bhalf_t, 2> res;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment